models.py
#! -*- coding: utf-8 -*-
# RNN-α model implementation
# Tested with tensorflow 1.15 + bert4keras 0.11.4
from bert4keras.models import *  # brings in RoFormerV2, Dropout, Add, LayerNormalization, FeedForward
from lru import LRU
from slru import SLRU
from rwkv import RWKV

RNN = LRU  # alternatives: SLRU, RWKV
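
# LRU / SLRU / RWKV are assumed to be sibling modules in this repo, each
# exposing a Keras layer whose constructor accepts (units, use_bias,
# kernel_initializer), as used by self.apply(...) below.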

class RNN_alpha(RoFormerV2):
    """RNN-α
    Change: the basic building block is replaced with an RNN
    """
    def initializer(self, shape, dtype=None, order=2, gain=1.0):
        # Pass-through: reuse RoFormerV2's initializer unchanged.
        return super(RNN_alpha, self).initializer(shape, dtype, order, gain)

    def apply_main_layers(self, inputs, index):
        """The main body of RNN-α is an RNN-based block.
        Order: RNN --> Add --> LN --> FFN --> Add --> LN
        """
        x = inputs

        rnn_name = 'Transformer-%d-RNN' % index
        ffn_name = 'Transformer-%d-FFN' % index

        # RNN sub-block: RNN --> Dropout --> residual Add --> LayerNorm
        xi = x
        x = self.apply(
            inputs=x,
            layer=RNN,
            units=(2 if RNN is SLRU else 1) * self.hidden_size,  # SLRU uses double width
            use_bias=False,
            kernel_initializer=self.initializer,
            name=rnn_name
        )
        x = self.apply(
            inputs=x,
            layer=Dropout,
            rate=self.dropout_rate,
            name='%s-Dropout' % rnn_name
        )
        x = self.apply(inputs=[xi, x], layer=Add, name='%s-Add' % rnn_name)
        x = self.apply(
            inputs=x,
            layer=LayerNormalization,
            zero_mean=False,
            scale=False,
            offset=False,
            epsilon=1e-12,
            name='%s-Norm' % rnn_name
        )

        # FFN sub-block: FFN --> Dropout --> residual Add --> LayerNorm
        xi = x
        x = self.apply(
            inputs=x,
            layer=FeedForward,
            units=self.intermediate_size,
            kernel_initializer=self.initializer,
            use_bias=False,
            name=ffn_name
        )
        x = self.apply(
            inputs=x,
            layer=Dropout,
            rate=self.dropout_rate,
            name='%s-Dropout' % ffn_name
        )
        x = self.apply(inputs=[xi, x], layer=Add, name='%s-Add' % ffn_name)
        x = self.apply(
            inputs=x,
            layer=LayerNormalization,
            zero_mean=False,
            scale=False,
            offset=False,
            epsilon=1e-12,
            name='%s-Norm' % ffn_name
        )
        return x
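
# Minimal usage sketch (not from the original file): instantiate an RNN-α
# encoder directly. All hyperparameter values here are illustrative
# assumptions; num_attention_heads is unused by the RNN block but is
# required by the Transformer base class in bert4keras.
if __name__ == '__main__':
    model = RNN_alpha(
        vocab_size=32000,
        hidden_size=512,
        num_hidden_layers=12,
        num_attention_heads=8,
        intermediate_size=2048,
        hidden_act='relu',
        dropout_rate=0.1,
        max_position=512
    )
    model.build()
    model.model.summary()  # the underlying keras.Model lives at model.model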