-
Notifications
You must be signed in to change notification settings - Fork 9
/
se.py
67 lines (46 loc) · 1.85 KB
/
se.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from tensorflow.keras.layers import GlobalAveragePooling2D, Reshape, Dense, Multiply, Add, Permute, Conv2D
from tensorflow.keras import backend as K
def squeeze_excite_block(input, ratio=16):
    """Create a channel-wise squeeze-excite block.

    Global-average-pools ``input`` down to one value per channel, feeds that
    vector through a bottleneck of two bias-free Dense layers (ReLU, then
    sigmoid), and multiplies ``input`` by the resulting per-channel gate.

    Args:
        input: input tensor (a 4D feature map). The channel axis follows
            ``K.image_data_format()``: the last axis for ``channels_last``,
            axis 1 for ``channels_first``.
        ratio: reduction ratio of the bottleneck. The hidden Dense layer has
            ``max(1, filters // ratio)`` units, where ``filters`` is the
            number of input channels.

    Returns:
        A keras tensor: ``input`` rescaled channel-wise by the learned gate.

    References:
        - [Squeeze and Excitation Networks](https://arxiv.org/abs/1709.01507)
    """
    init = input
    channels_first = K.image_data_format() == 'channels_first'
    channel_axis = 1 if channels_first else -1
    filters = init.shape[channel_axis]
    se_shape = (1, 1, filters)

    se = GlobalAveragePooling2D()(init)
    se = Reshape(se_shape)(se)
    # Keras rejects Dense(units=0), which `filters // ratio` would produce
    # whenever the input has fewer channels than `ratio`.
    se = Dense(max(1, filters // ratio), activation='relu',
               kernel_initializer='he_normal', use_bias=False)(se)
    se = Dense(filters, activation='sigmoid',
               kernel_initializer='he_normal', use_bias=False)(se)

    if channels_first:
        # The gate is laid out as (1, 1, C); move channels first -> (C, 1, 1)
        # so it broadcasts against a (C, H, W) feature map.
        se = Permute((3, 1, 2))(se)

    return Multiply()([init, se])
def spatial_squeeze_excite_block(input):
    """Create a spatial squeeze-excite block.

    A bias-free 1x1 convolution with a single sigmoid output channel yields
    one gate value per spatial position; ``input`` is multiplied by that
    map, which broadcasts across the channel axis.

    Args:
        input: input tensor (a 4D feature map).

    Returns:
        A keras tensor: ``input`` rescaled position-wise by the learned gate.

    References:
        - [Concurrent Spatial and Channel Squeeze & Excitation in Fully
          Convolutional Networks](https://arxiv.org/abs/1803.02579)
    """
    se = Conv2D(1, (1, 1), activation='sigmoid', use_bias=False,
                kernel_initializer='he_normal')(input)
    # Use the imported `Multiply` layer; the lowercase functional
    # `multiply` is not in scope in this module.
    return Multiply()([input, se])
def channel_spatial_squeeze_excite(input, ratio=16):
    """Create a concurrent spatial and channel squeeze-excite (scSE) block.

    Applies a channel squeeze-excite block and a spatial squeeze-excite
    block to the same input and sums their outputs element-wise.

    Args:
        input: input tensor (a 4D feature map).
        ratio: reduction ratio forwarded to ``squeeze_excite_block`` for
            the channel branch.

    Returns:
        A keras tensor: the element-wise sum of the channel-recalibrated
        and spatially-recalibrated versions of ``input``.

    References:
        - [Squeeze and Excitation Networks](https://arxiv.org/abs/1709.01507)
        - [Concurrent Spatial and Channel Squeeze & Excitation in Fully
          Convolutional Networks](https://arxiv.org/abs/1803.02579)
    """
    cse = squeeze_excite_block(input, ratio)
    sse = spatial_squeeze_excite_block(input)
    # `Add` is a layer class: instantiate it first, then call it on the
    # list of tensors to merge.
    return Add()([cse, sse])