-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
66 lines (55 loc) · 1.91 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
if __name__ == '__main__':
    # NOTE(review): in the original scrape the body below sat at column 0,
    # which is a SyntaxError directly after the guard; re-indented so the
    # script parses. Imports are kept function-local as the author wrote them.
    import paddle
    from amazing_cv.attention import *
    from amazing_cv.mlp import *

    # --- retained usage examples for the other attention/MLP modules ---
    # input = paddle.randn([50, 49, 512])
    # a2 = DoubleAttention(512, 128, 128, True)
    # output = a2(input)
    # print(output.shape)
    # triplet = TripletAttention()
    # output = triplet(input)
    # print(output.shape)
    # eca = ECAAttention(kernel_size=3)
    # output = eca(input)
    # print(output.shape)
    # ssa = SimplifiedScaledDotProductAttention(d_model=512, h=8)
    # output = ssa(input, input, input)
    # print(output.shape)
    # input = paddle.randn([50, 512, 7, 7])
    # se = ShuffleAttention(channel=512, G=8)
    # output = se(input)
    # print(output.shape)
    # input = paddle.randn([50, 49, 512])
    # sa = ScaledDotProductAttention(d_model=512, d_k=512, d_v=512, h=8)
    # output = sa(input, input, input)
    # print(output.shape)
    # input = paddle.randn([50, 512, 7, 7])
    # sge = SpatialGroupEnhance(groups=8)
    # output = sge(input)
    # print(output.shape)
    # input = paddle.randn([50, 512, 7, 7])
    # se = SEAttention(channel=512, reduction=8)
    # output = se(input)
    # print(output.shape)
    # input = paddle.randn([50, 512, 7, 7])
    # psa = PSA(channel=512, reduction=8)
    # output = psa(input)
    # a = output.reshape([-1]).sum()
    # a.backward()
    # print(output.shape)
    # input = paddle.randn([50, 49, 512])
    # sa = MobileViTV2Attention(d_model=512)
    # output = sa(input)
    # print(output.shape)

    # Active smoke test: MUSEAttention self-attention pass.
    # `x` renamed from `input` to avoid shadowing the builtin.
    # Tensor is (batch=50, seq_len=49, d_model=512) — presumably the layout
    # MUSEAttention expects, matching the other examples; confirm in amazing_cv.
    x = paddle.randn([50, 49, 512])
    sa = MUSEAttention(d_model=512, d_k=512, d_v=512, h=8)
    output = sa(x, x, x)  # query, key, value all set to x (self-attention)
    print(output.shape)

    # num_tokens = 10000
    # bs = 50
    # len_sen = 49
    # num_layers = 6
    # input = paddle.randint(num_tokens, (bs, len_sen)) # bs,len_sen
    # gmlp = gMLP(num_tokens=num_tokens, len_sen=len_sen, dim=512, d_ff=1024)
    # output = gmlp(input)
    # print(output.shape)