-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
stdc.py
422 lines (381 loc) · 15.8 KB
/
stdc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/MichaelFan01/STDC-Seg."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmseg.ops import resize
from ..builder import BACKBONES, build_backbone
from .bisenetv1 import AttentionRefinementModule
class STDCModule(BaseModule):
    """STDC building block.

    A 1x1 conv reduces the input to ``out_channels // 2`` channels and is
    followed by ``num_convs - 1`` 3x3 convs. Each 3x3 conv halves the channel
    count of its input, except the last one, which keeps it. The outputs of
    all convs are concatenated along the channel axis.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels before scaling.
        stride (int): The number of stride for the first conv layer. Only a
            value of 2 enables the downsampling branch.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): The activation config for the 3x3 conv layers.
            Default: None.
        num_convs (int): Numbers of conv layers. Must be larger than 1.
            Default: 4.
        fusion_type (str): Type of fusion operation, 'add' or 'cat'.
            Default: 'add'.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 norm_cfg=None,
                 act_cfg=None,
                 num_convs=4,
                 fusion_type='add',
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        assert num_convs > 1
        assert fusion_type in ['add', 'cat']
        self.stride = stride
        self.with_downsample = bool(self.stride == 2)
        self.fusion_type = fusion_type
        self.layers = ModuleList()

        half_channels = out_channels // 2
        # Channel-reducing 1x1 conv that opens every block.
        reduce_conv = ConvModule(
            in_channels, half_channels, kernel_size=1, norm_cfg=norm_cfg)

        if not self.with_downsample:
            self.layers.append(reduce_conv)
        else:
            # Depthwise stride-2 conv applied right after the 1x1 conv.
            self.downsample = ConvModule(
                half_channels,
                half_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=half_channels,
                norm_cfg=norm_cfg,
                act_cfg=None)
            if self.fusion_type != 'add':
                # 'cat': the first branch output is downsampled by pooling
                # in `forward_cat`, so the first layer stays a plain conv.
                self.layers.append(reduce_conv)
                self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
            else:
                # 'add': downsampling is folded into the first layer and the
                # identity path is projected to `out_channels` at stride 2.
                self.layers.append(nn.Sequential(reduce_conv, self.downsample))
                depthwise = ConvModule(
                    in_channels,
                    in_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    groups=in_channels,
                    norm_cfg=norm_cfg,
                    act_cfg=None)
                pointwise = ConvModule(
                    in_channels,
                    out_channels,
                    1,
                    norm_cfg=norm_cfg,
                    act_cfg=None)
                self.skip = Sequential(depthwise, pointwise)

        last_idx = num_convs - 1
        for idx in range(1, num_convs):
            # Halve the width at every step, except for the final conv which
            # keeps the width of its input.
            if idx == last_idx:
                divisor = 2**idx
            else:
                divisor = 2**(idx + 1)
            self.layers.append(
                ConvModule(
                    out_channels // 2**idx,
                    out_channels // divisor,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

    def forward(self, inputs):
        """Dispatch to the fusion variant selected by ``fusion_type``."""
        if self.fusion_type != 'add':
            return self.forward_cat(inputs)
        return self.forward_add(inputs)

    def forward_add(self, inputs):
        """Concatenate all layer outputs and add the (projected) input."""
        feats = []
        current = inputs.clone()
        for block in self.layers:
            current = block(current)
            feats.append(current)
        identity = self.skip(inputs) if self.with_downsample else inputs
        return torch.cat(feats, dim=1) + identity

    def forward_cat(self, inputs):
        """Concatenate all layer outputs, pooling the first if downsampling."""
        head = self.layers[0](inputs)
        # With downsampling, the later layers consume the strided version of
        # the first output.
        current = self.downsample(head) if self.with_downsample else head
        tail_feats = []
        for block in self.layers[1:]:
            current = block(current)
            tail_feats.append(current)
        # The first output is brought to the same resolution via `self.skip`.
        first = self.skip(head) if self.with_downsample else head
        return torch.cat([first] + tail_feats, dim=1)
class FeatureFusionModule(BaseModule):
    """Feature Fusion Module.

    Unlike the FeatureFusionModule of BiSeNetV1, which only uses one
    ConvModule in `self.conv_atten`, this module builds `self.attention`
    from two ConvModules whose intermediate channel number is
    ``out_channels // scale_factor``.

    Args:
        in_channels (int): The number of input channels (of the concatenated
            spatial and context inputs).
        out_channels (int): The number of output channels.
        scale_factor (int): The number of channel scale factor.
            Default: 4.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): The activation config for conv layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 scale_factor=4,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        mid_channels = out_channels // scale_factor
        self.conv0 = ConvModule(
            in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)

        # Channel attention: global pooling -> squeeze -> expand -> sigmoid.
        global_pool = nn.AdaptiveAvgPool2d((1, 1))
        squeeze = ConvModule(
            out_channels,
            mid_channels,
            1,
            norm_cfg=None,
            bias=False,
            act_cfg=act_cfg)
        expand = ConvModule(
            mid_channels,
            out_channels,
            1,
            norm_cfg=None,
            bias=False,
            act_cfg=None)
        self.attention = nn.Sequential(global_pool, squeeze, expand,
                                       nn.Sigmoid())

    def forward(self, spatial_inputs, context_inputs):
        """Fuse the two inputs and re-weight the result channel-wise."""
        fused = self.conv0(
            torch.cat((spatial_inputs, context_inputs), dim=1))
        weights = self.attention(fused)
        # Residual form: attended features plus the un-attended ones.
        return fused * weights + fused
@BACKBONES.register_module()
class STDCNet(BaseModule):
    """This backbone is the implementation of `Rethinking BiSeNet For Real-time
    Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.

    Args:
        stdc_type (str): The type of backbone structure, a key of
            ``arch_settings``. `STDCNet1` and `STDCNet2` denote the two main
            backbones in the paper, whose FLOPs are 813M and 1446M,
            respectively.
        in_channels (int): The num of input_channels.
        channels (tuple[int]): The output channels for each stage. Must have
            exactly 5 entries.
        bottleneck_type (str): The type of STDC Module type, the value must
            be 'add' or 'cat'.
        norm_cfg (dict): Config dict for normalization layer.
        act_cfg (dict): The activation config for conv layers.
        num_convs (int): Numbers of conv layer at each STDC Module.
            Default: 4.
        with_final_conv (bool): Whether add a conv layer at the Module output.
            Default: False.
        pretrained (str, optional): Model pretrained path. It is only stored
            on the instance by this class. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Example:
        >>> import torch
        >>> stdc_type = 'STDCNet1'
        >>> in_channels = 3
        >>> channels = (32, 64, 256, 512, 1024)
        >>> bottleneck_type = 'cat'
        >>> norm_cfg = dict(type='BN')
        >>> act_cfg = dict(type='ReLU')
        >>> inputs = torch.rand(1, 3, 1024, 2048)
        >>> self = STDCNet(stdc_type, in_channels, channels,
        ...                bottleneck_type, norm_cfg, act_cfg).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 256, 128, 256])
        outputs[1].shape = torch.Size([1, 512, 64, 128])
        outputs[2].shape = torch.Size([1, 1024, 32, 64])
    """

    # Per-stage tuples of strides; each entry creates one STDCModule.
    arch_settings = {
        'STDCNet1': [(2, 1), (2, 1), (2, 1)],
        'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)]
    }

    def __init__(self,
                 stdc_type,
                 in_channels,
                 channels,
                 bottleneck_type,
                 norm_cfg,
                 act_cfg,
                 num_convs=4,
                 with_final_conv=False,
                 pretrained=None,
                 init_cfg=None):
        super(STDCNet, self).__init__(init_cfg=init_cfg)
        assert stdc_type in self.arch_settings, \
            f'invalid structure {stdc_type} for STDCNet.'
        assert bottleneck_type in ['add', 'cat'],\
            f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}'

        assert len(channels) == 5,\
            f'invalid channels length {len(channels)} for STDCNet.'

        self.in_channels = in_channels
        self.channels = channels
        self.stage_strides = self.arch_settings[stdc_type]
        self.pretrained = pretrained
        # The attribute used to be misspelled; the old name is kept as an
        # alias so existing code reading `prtrained` keeps working.
        self.prtrained = pretrained
        self.num_convs = num_convs
        self.with_final_conv = with_final_conv

        # Two plain stride-2 convs form the shallow part of the network.
        self.stages = ModuleList([
            ConvModule(
                self.in_channels,
                self.channels[0],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                self.channels[0],
                self.channels[1],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        ])
        # `self.num_shallow_features` is the number of shallow modules in
        # `STDCNet`, which is noted as `Stage1` and `Stage2` in original paper.
        # They are both not used for following modules like Attention
        # Refinement Module and Feature Fusion Module.
        # Thus they would be cut from `outs`. Please refer to Figure 4
        # of original paper for more details.
        self.num_shallow_features = len(self.stages)

        for strides in self.stage_strides:
            # Index of the most recently added stage; its output width is
            # the input width of the stage being created.
            idx = len(self.stages) - 1
            self.stages.append(
                self._make_stage(self.channels[idx], self.channels[idx + 1],
                                 strides, norm_cfg, act_cfg, bottleneck_type))
        # After appending, `self.stages` is a ModuleList including several
        # shallow modules and STDCModules.
        # (len(self.stages) ==
        # self.num_shallow_features + len(self.stage_strides))
        if self.with_final_conv:
            self.final_conv = ConvModule(
                self.channels[-1],
                max(1024, self.channels[-1]),
                1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def _make_stage(self, in_channels, out_channels, strides, norm_cfg,
                    act_cfg, bottleneck_type):
        """Build one stage as a Sequential of STDCModules.

        One STDCModule is created per entry of ``strides``. Only the first
        module takes ``in_channels`` as input; the rest map ``out_channels``
        to ``out_channels``.
        """
        layers = []
        for i, stride in enumerate(strides):
            layers.append(
                STDCModule(
                    in_channels if i == 0 else out_channels,
                    out_channels,
                    stride,
                    norm_cfg,
                    act_cfg,
                    num_convs=self.num_convs,
                    fusion_type=bottleneck_type))
        return Sequential(*layers)

    def forward(self, x):
        """Run all stages and return the outputs of the STDC stages.

        Returns:
            tuple[Tensor]: Outputs of every stage after the first
            ``self.num_shallow_features`` ones. If ``with_final_conv`` is
            set, the last output is passed through ``self.final_conv``.
        """
        outs = []
        for stage in self.stages:
            x = stage(x)
            outs.append(x)
        if self.with_final_conv:
            outs[-1] = self.final_conv(outs[-1])
        # Drop the shallow stem features, which later modules do not use.
        outs = outs[self.num_shallow_features:]
        return tuple(outs)
@BACKBONES.register_module()
class STDCContextPathNet(BaseModule):
    """STDCNet with Context Path.

    The `outs` below is the list of feature maps returned by the backbone.
    `outs[0]` is outputted for `STDCHead`, where Detail Loss would be
    calculated by Detail Ground-truth. The remaining feature maps are
    consumed from the last one backwards, one per Attention Refinement
    Module. Besides, `outs[0]` and the last output of the Attention
    Refinement path are concatenated for Feature Fusion Module. Then, this
    fusion feature map `feat_fuse` would be outputted for `decode_head`.
    More details please refer to Figure 4 of original paper.

    Args:
        backbone_cfg (dict): Config dict for stdc backbone.
        last_in_channels (tuple(int)): The number of channels of last
            two feature maps from stdc backbone, deepest one first.
            Default: (1024, 512).
        out_channels (int): The channels of output feature maps.
            Default: 128.
        ffm_cfg (dict): Config dict for Feature Fusion Module. Default:
            `dict(in_channels=512, out_channels=256, scale_factor=4)`.
        upsample_mode (str): Algorithm used for upsampling:
            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
            ``'trilinear'``. Default: ``'nearest'``.
        align_corners (bool, optional): align_corners argument of
            F.interpolate. It must be `None` if upsample_mode is
            ``'nearest'``. Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Return:
        outputs (tuple): ``(outs[0], *arms_out, feat_fuse)``, the feature
            maps for auxiliary heads and decoder head.
    """

    def __init__(self,
                 backbone_cfg,
                 last_in_channels=(1024, 512),
                 out_channels=128,
                 ffm_cfg=dict(
                     in_channels=512, out_channels=256, scale_factor=4),
                 upsample_mode='nearest',
                 align_corners=None,
                 norm_cfg=dict(type='BN'),
                 init_cfg=None):
        super(STDCContextPathNet, self).__init__(init_cfg=init_cfg)
        self.backbone = build_backbone(backbone_cfg)
        self.arms = ModuleList()
        self.convs = ModuleList()
        # One refinement module plus one 3x3 conv per deep feature map.
        for channels in last_in_channels:
            self.arms.append(AttentionRefinementModule(channels, out_channels))
            self.convs.append(
                ConvModule(
                    out_channels,
                    out_channels,
                    3,
                    padding=1,
                    norm_cfg=norm_cfg))
        # Projects the globally pooled deepest feature to `out_channels`.
        self.conv_avg = ConvModule(
            last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg)

        self.ffm = FeatureFusionModule(**ffm_cfg)

        self.upsample_mode = upsample_mode
        self.align_corners = align_corners

    def forward(self, x):
        """Compute the context-path features and the fused feature map."""
        outs = list(self.backbone(x))
        # Global context: pool the last backbone output to 1x1, project it,
        # and resize it back to that output's spatial size.
        avg = F.adaptive_avg_pool2d(outs[-1], 1)
        avg_feat = self.conv_avg(avg)

        feature_up = resize(
            avg_feat,
            size=outs[-1].shape[2:],
            mode=self.upsample_mode,
            align_corners=self.align_corners)
        arms_out = []
        for i, (arm, conv) in enumerate(zip(self.arms, self.convs)):
            # Walk the backbone outputs from the last one backwards.
            src = len(outs) - 1 - i
            x_arm = arm(outs[src]) + feature_up
            # Resize to the spatial size of the next-shallower output.
            feature_up = resize(
                x_arm,
                size=outs[src - 1].shape[2:],
                mode=self.upsample_mode,
                align_corners=self.align_corners)
            feature_up = conv(feature_up)
            arms_out.append(feature_up)

        # `arms_out[1]` is the output of the second refinement step, i.e.
        # the last one for the default two-entry `last_in_channels`.
        feat_fuse = self.ffm(outs[0], arms_out[1])

        # The `outputs` has four feature maps.
        # `outs[0]` is outputted for `STDCHead` auxiliary head.
        # Two feature maps of `arms_out` are outputted for auxiliary head.
        # `feat_fuse` is outputted for decoder head.
        outputs = [outs[0]] + list(arms_out) + [feat_fuse]
        return tuple(outputs)