-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
deeplab.py
297 lines (251 loc) · 11.3 KB
/
deeplab.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.cvlibs import manager
from paddleseg.models import layers
from paddleseg.utils import utils
__all__ = ['DeepLabV3P', 'DeepLabV3']
@manager.MODELS.add_component
class DeepLabV3P(nn.Layer):
    """
    The DeepLabV3Plus implementation based on PaddlePaddle.

    The original article refers to
    Liang-Chieh Chen, et al. "Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation"
    (https://arxiv.org/abs/1802.02611)

    Args:
        num_classes (int): The unique number of target classes.
        backbone (paddle.nn.Layer): Backbone network, currently supports Resnet50_vd/Resnet101_vd/Xception65.
            It must expose a ``feat_channels`` sequence, and calling it must return an indexable
            collection of feature maps.
        backbone_indices (tuple, optional): Indices into the backbone outputs. The first index selects the
            low-level feature used by the decoder; the second selects the input of the ASPP module.
            Default: (0, 3).
        aspp_ratios (tuple, optional): The dilation rates used in the ASPP module.
            If output_stride=16, aspp_ratios should be set as (1, 6, 12, 18).
            If output_stride=8, aspp_ratios is (1, 12, 24, 36).
            Default: (1, 6, 12, 18).
        aspp_out_channels (int, optional): The output channels of the ASPP module. Default: 256.
        align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
            e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
        pretrained (str, optional): The path or url of pretrained model. Default: None.
        data_format (str, optional): Data format that specifies the layout of input. It can be "NCHW" or "NHWC".
            Default: "NCHW".

    Raises:
        ValueError: If ``backbone_indices`` has fewer than two entries, or ``data_format`` is
            neither "NCHW" nor "NHWC".
    """

    def __init__(self,
                 num_classes,
                 backbone,
                 backbone_indices=(0, 3),
                 aspp_ratios=(1, 6, 12, 18),
                 aspp_out_channels=256,
                 align_corners=False,
                 pretrained=None,
                 data_format="NCHW"):
        super().__init__()
        # Fail fast with a clear message instead of an IndexError inside the head
        # (which reads backbone_indices[0] and backbone_indices[1]).
        if len(backbone_indices) < 2:
            raise ValueError(
                f"`backbone_indices` must contain two indices (low-level feature, "
                f"ASPP input), but got {backbone_indices}.")
        # forward() treats every non-"NCHW" value as NHWC, so reject anything else up front.
        if data_format not in ('NCHW', 'NHWC'):
            raise ValueError(
                f"`data_format` must be 'NCHW' or 'NHWC', but got {data_format}.")

        self.backbone = backbone
        backbone_channels = [
            backbone.feat_channels[i] for i in backbone_indices
        ]
        self.head = DeepLabV3PHead(num_classes,
                                   backbone_indices,
                                   backbone_channels,
                                   aspp_ratios,
                                   aspp_out_channels,
                                   align_corners,
                                   data_format=data_format)
        self.align_corners = align_corners
        self.pretrained = pretrained
        self.data_format = data_format
        self.init_weight()

    def forward(self, x):
        """Run backbone + head and resize every logit map back to the input's spatial size."""
        feat_list = self.backbone(x)
        logit_list = self.head(feat_list)
        # The spatial (H, W) dims live at different axes depending on the layout.
        if self.data_format == 'NCHW':
            ori_shape = x.shape[2:]
        else:
            ori_shape = x.shape[1:3]
        return [
            F.interpolate(logit,
                          ori_shape,
                          mode='bilinear',
                          align_corners=self.align_corners,
                          data_format=self.data_format) for logit in logit_list
        ]

    def init_weight(self):
        """Load the whole model from ``self.pretrained`` when a path/url was given."""
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)
class DeepLabV3PHead(nn.Layer):
    """
    The DeepLabV3PHead implementation based on PaddlePaddle.

    Args:
        num_classes (int): The unique number of target classes.
        backbone_indices (tuple): Two values in the tuple indicate the indices of output of backbone.
            The first index will be taken as a low-level feature in the Decoder component;
            the second one will be taken as input of the ASPP component.
            Usually a backbone consists of four downsampling stages and returns the output of
            each stage. If we set it as (0, 3), it means taking the feature map of the first
            stage in the backbone as the low-level feature used in the Decoder, and the feature map
            of the fourth stage as input of ASPP.
        backbone_channels (tuple): The same length as "backbone_indices". It indicates the channels of
            the corresponding index.
        aspp_ratios (tuple): The dilation rates used in the ASPP module.
        aspp_out_channels (int): The output channels of the ASPP module. It is forwarded to the
            Decoder so that its fusion convolution matches the concatenated width.
        align_corners (bool): An argument of F.interpolate. It should be set to False when the output size of feature
            is even, e.g. 1024x512, otherwise it is True, e.g. 769x769.
        data_format (str, optional): Data format that specifies the layout of input. It can be "NCHW" or "NHWC".
            Default: "NCHW".
    """

    def __init__(self,
                 num_classes,
                 backbone_indices,
                 backbone_channels,
                 aspp_ratios,
                 aspp_out_channels,
                 align_corners,
                 data_format='NCHW'):
        super().__init__()
        self.aspp = layers.ASPPModule(aspp_ratios,
                                      backbone_channels[1],
                                      aspp_out_channels,
                                      align_corners,
                                      use_sep_conv=True,
                                      image_pooling=True,
                                      data_format=data_format)
        # Pass the ASPP width through: the decoder concatenates ASPP output with the
        # projected low-level feature and must size its next conv accordingly.
        self.decoder = Decoder(num_classes,
                               backbone_channels[0],
                               align_corners,
                               data_format=data_format,
                               aspp_out_channels=aspp_out_channels)
        self.backbone_indices = backbone_indices

    def forward(self, feat_list):
        """Return a one-element list holding the decoder logits."""
        low_level_feat = feat_list[self.backbone_indices[0]]
        aspp_out = self.aspp(feat_list[self.backbone_indices[1]])
        return [self.decoder(aspp_out, low_level_feat)]


@manager.MODELS.add_component
class DeepLabV3(nn.Layer):
    """
    The DeepLabV3 implementation based on PaddlePaddle.

    The original article refers to
    Liang-Chieh Chen, et al. "Rethinking Atrous Convolution for Semantic Image Segmentation"
    (https://arxiv.org/pdf/1706.05587.pdf).

    Args:
        Please refer to DeepLabV3P. Differences: ``backbone_indices`` holds a single index
        (default: (3, )) whose feature map is fed to ASPP, and there is no decoder; a 1x1
        convolution maps the ASPP output directly to class logits.
        data_format (str, optional): Data format that specifies the layout of input. It can be
            "NCHW" or "NHWC". Default: "NCHW".
    """

    def __init__(self,
                 num_classes,
                 backbone,
                 backbone_indices=(3, ),
                 aspp_ratios=(1, 6, 12, 18),
                 aspp_out_channels=256,
                 align_corners=False,
                 pretrained=None,
                 data_format="NCHW"):
        super().__init__()
        self.backbone = backbone
        backbone_channels = [
            backbone.feat_channels[i] for i in backbone_indices
        ]
        self.head = DeepLabV3Head(num_classes,
                                  backbone_indices,
                                  backbone_channels,
                                  aspp_ratios,
                                  aspp_out_channels,
                                  align_corners,
                                  data_format=data_format)
        self.align_corners = align_corners
        self.pretrained = pretrained
        self.data_format = data_format
        self.init_weight()

    def forward(self, x):
        """Run backbone + head and resize every logit map back to the input's spatial size."""
        feat_list = self.backbone(x)
        logit_list = self.head(feat_list)
        # Pick the (H, W) dims according to the layout instead of assuming NCHW.
        if self.data_format == 'NCHW':
            ori_shape = x.shape[2:]
        else:
            ori_shape = x.shape[1:3]
        return [
            F.interpolate(logit,
                          ori_shape,
                          mode='bilinear',
                          align_corners=self.align_corners,
                          data_format=self.data_format) for logit in logit_list
        ]

    def init_weight(self):
        """Load the whole model from ``self.pretrained`` when a path/url was given."""
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)


class DeepLabV3Head(nn.Layer):
    """
    The DeepLabV3Head implementation based on PaddlePaddle.

    Args:
        Please refer to DeepLabV3PHead. Here ``backbone_indices``/``backbone_channels`` hold a
        single entry: the feature map fed to ASPP (built without separable convolutions).
        data_format (str, optional): Data format that specifies the layout of input. It can be
            "NCHW" or "NHWC". Default: "NCHW".
    """

    def __init__(self,
                 num_classes,
                 backbone_indices,
                 backbone_channels,
                 aspp_ratios,
                 aspp_out_channels,
                 align_corners,
                 data_format='NCHW'):
        super().__init__()
        self.aspp = layers.ASPPModule(aspp_ratios,
                                      backbone_channels[0],
                                      aspp_out_channels,
                                      align_corners,
                                      use_sep_conv=False,
                                      image_pooling=True,
                                      data_format=data_format)
        self.cls = nn.Conv2D(in_channels=aspp_out_channels,
                             out_channels=num_classes,
                             kernel_size=1,
                             data_format=data_format)
        self.backbone_indices = backbone_indices

    def forward(self, feat_list):
        """Return a one-element list holding the classifier logits."""
        aspp_out = self.aspp(feat_list[self.backbone_indices[0]])
        return [self.cls(aspp_out)]


class Decoder(nn.Layer):
    """
    Decoder module of DeepLabV3P model.

    It projects the low-level feature to 48 channels, upsamples the ASPP output to the
    low-level spatial size, concatenates both along the channel axis, refines the result
    with two separable 3x3 ConvBNReLU layers and classifies with a 1x1 convolution.

    Args:
        num_classes (int): The number of classes.
        in_channels (int): The number of channels of the low-level feature fed to the decoder.
        align_corners (bool): An argument of F.interpolate used when upsampling the ASPP output.
        data_format (str, optional): Data format that specifies the layout of input. It can be
            "NCHW" or "NHWC". Default: "NCHW".
        aspp_out_channels (int, optional): Channels of the ASPP output passed to ``forward``.
            Default: 256 (which reproduces the historical fused width of 304).
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 align_corners,
                 data_format='NCHW',
                 aspp_out_channels=256):
        super().__init__()
        self.data_format = data_format
        low_level_channels = 48
        self.conv_bn_relu1 = layers.ConvBNReLU(in_channels=in_channels,
                                               out_channels=low_level_channels,
                                               kernel_size=1,
                                               data_format=data_format)
        # The fused tensor is concat(ASPP output, projected low-level feature), so its width
        # depends on aspp_out_channels; hard-coding 304 broke any ASPP width other than 256.
        self.conv_bn_relu2 = layers.SeparableConvBNReLU(
            in_channels=aspp_out_channels + low_level_channels,
            out_channels=256,
            kernel_size=3,
            padding=1,
            data_format=data_format)
        self.conv_bn_relu3 = layers.SeparableConvBNReLU(in_channels=256,
                                                        out_channels=256,
                                                        kernel_size=3,
                                                        padding=1,
                                                        data_format=data_format)
        self.conv = nn.Conv2D(in_channels=256,
                              out_channels=num_classes,
                              kernel_size=1,
                              data_format=data_format)
        self.align_corners = align_corners

    def forward(self, x, low_level_feat):
        """Fuse the ASPP output ``x`` with ``low_level_feat`` and return class logits."""
        low_level_feat = self.conv_bn_relu1(low_level_feat)
        # Spatial dims and channel axis differ between layouts.
        if self.data_format == 'NCHW':
            low_level_shape = low_level_feat.shape[-2:]
            axis = 1
        else:
            low_level_shape = low_level_feat.shape[1:3]
            axis = -1
        x = F.interpolate(x,
                          low_level_shape,
                          mode='bilinear',
                          align_corners=self.align_corners,
                          data_format=self.data_format)
        x = paddle.concat([x, low_level_feat], axis=axis)
        x = self.conv_bn_relu2(x)
        x = self.conv_bn_relu3(x)
        x = self.conv(x)
        return x