-
Notifications
You must be signed in to change notification settings - Fork 173
/
gtrxl.py
324 lines (307 loc) · 12.7 KB
/
gtrxl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
"""
Gated Transformer XL (GTrXL) <link https://arxiv.org/abs/1910.06764 link> is a stabilized transformer architecture for reinforcement learning.
This document mainly includes:
- PyTorch implementation of GTrXL.
- An example to test GTrXL.
"""
from typing import Optional, Dict
import warnings
import numpy as np
import torch
import torch.nn as nn
import treetensor
from ding.torch_utils import GRUGatingUnit, build_normalization
from ding.torch_utils.network.nn_module import fc_block
from ding.torch_utils.network.gtrxl import PositionalEmbedding, Memory, AttentionXL
class GatedTransformerXLLayer(torch.nn.Module):
    """
    **Overview**:
        The basic layer design of Gated Transformer-XL. This module mainly includes AttentionXL,
        Feed-Forward-Network, layer normalization, and GRU-gating.
    """

    def __init__(
            self,
            input_dim: int,
            head_dim: int,
            hidden_dim: int,
            head_num: int,
            mlp_num: int,
            dropout: nn.Module,
            activation: nn.Module,
            gru_gating: bool = True,
            gru_bias: float = 2.
    ) -> None:
        """
        **Overview**:
            Build the sub-modules of one GTrXL layer.
        **Arguments**:
            - input_dim: Feature size of the layer input; also the output size of the feed-forward network.
            - head_dim: Per-head dimension forwarded to ``AttentionXL``.
            - hidden_dim: Width of the intermediate feed-forward layers.
            - head_num: Number of attention heads forwarded to ``AttentionXL``.
            - mlp_num: Number of ``fc_block`` layers in the feed-forward network.
            - dropout: Module applied after attention, between feed-forward layers and after the feed-forward network.
            - activation: Module used inside every ``fc_block`` and applied to the attention output.
            - gru_gating: If truthy, GRU gates are used instead of plain residual sums.
            - gru_bias: Second argument handed to each ``GRUGatingUnit``.
        """
        super(GatedTransformerXLLayer, self).__init__()
        self.dropout = dropout
        # Decide whether to use GRU-gating.
        self.gating = gru_gating
        # Use plain truthiness here because ``forward`` also tests ``self.gating`` by truthiness;
        # an ``is True`` check would skip gate creation for truthy non-bool flags and crash later.
        if self.gating:
            self.gate1 = GRUGatingUnit(input_dim, gru_bias)
            self.gate2 = GRUGatingUnit(input_dim, gru_bias)
        # Build attention block using the AttentionXL class.
        self.attention = AttentionXL(
            input_dim,
            head_dim,
            head_num,
            dropout,
        )
        # Build Feed-Forward-Network: input_dim -> hidden_dim (x mlp_num - 1) -> input_dim.
        widths = [input_dim] + [hidden_dim] * (mlp_num - 1) + [input_dim]
        ffn_modules = []
        last_idx = mlp_num - 1
        for idx in range(mlp_num):
            ffn_modules.append(fc_block(widths[idx], widths[idx + 1], activation=activation))
            # Dropout between consecutive fc blocks (not after the last one inside the loop).
            if idx != last_idx:
                ffn_modules.append(self.dropout)
        # A trailing dropout always closes the feed-forward stack.
        # NOTE(review): ``forward`` applies ``self.dropout`` again on the MLP output, so the
        # final activation is dropped out twice when a real ``nn.Dropout`` is used - confirm intended.
        ffn_modules.append(self.dropout)
        self.mlp = nn.Sequential(*ffn_modules)
        # Build layer norm (one before attention, one before the feed-forward network).
        self.layernorm1 = build_normalization('LN')(input_dim)
        self.layernorm2 = build_normalization('LN')(input_dim)
        self.activation = activation

    # delimiter
    def forward(
            self,
            inputs: torch.Tensor,
            pos_embedding: torch.Tensor,
            u: torch.nn.Parameter,
            v: torch.nn.Parameter,
            memory: torch.Tensor,
            mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        **Overview**:
            The forward computation graph of GTrXL layer.
        **Arguments**:
            - inputs: Current segment; concatenated after ``memory`` along dimension 0.
            - pos_embedding: Positional embedding forwarded to the attention module.
            - u: Parameter forwarded to the attention module (global content bias).
            - v: Parameter forwarded to the attention module (global positional bias).
            - memory: Hidden states of past segments for this layer.
            - mask: Optional attention mask forwarded to the attention module.
        **Returns**:
            - The layer output after the attention and feed-forward sub-blocks.
        """
        # Concat memory with input across sequence dimension. The shape is: [full_sequence, batch_size, input_dim]
        full_input = torch.cat([memory, inputs], dim=0)
        # In GTrXL, the layer normalization is put before the attention layer.
        normed_full = self.layernorm1(full_input)
        # Attention module, followed by dropout and the activation.
        attn_out = self.attention(inputs, pos_embedding, normed_full, u, v, mask=mask)
        attn_out = self.activation(self.dropout(attn_out))
        # In GTrXL, gating layer replace the resnet layer in TrXL.
        if self.gating:
            attn_block_out = self.gate1(inputs, attn_out)
        else:
            attn_block_out = inputs + attn_out
        # Feed Forward Network on the normalized attention-block output.
        ffn_out = self.dropout(self.mlp(self.layernorm2(attn_block_out)))
        if self.gating:
            return self.gate2(attn_block_out, ffn_out)
        return attn_block_out + ffn_out
# delimiter
class GTrXL(nn.Module):
    """
    **Overview**:
        PyTorch implementation for GTrXL, which is used to model the long-term time dependency in reinforcement learning.
    """

    def __init__(
            self,
            input_dim: int,
            head_dim: int = 128,
            embedding_dim: int = 256,
            head_num: int = 2,
            mlp_num: int = 2,
            layer_num: int = 3,
            memory_len: int = 64,
            dropout_ratio: float = 0.,
            activation: nn.Module = nn.ReLU(),
            gru_gating: bool = True,
            gru_bias: float = 2.,
            use_embedding_layer: bool = True,
    ) -> None:
        """
        **Overview**:
            Build the embedding layer, positional embedding, GTrXL layers and the bias parameters ``u`` and ``v``.
        **Arguments**:
            - input_dim: Size of the input feature. A list/tuple is reduced to the product of its entries.
            - head_dim: Dimension of each attention head.
            - embedding_dim: Embedding size used by every layer; must be even.
            - head_num: Number of attention heads.
            - mlp_num: Number of fc blocks in each layer's feed-forward network.
            - layer_num: Number of stacked ``GatedTransformerXLLayer``.
            - memory_len: Number of past steps kept in memory.
            - dropout_ratio: Dropout probability; ``nn.Identity`` is used when it is not positive.
            - activation: Activation module shared by the embedding layer and all GTrXL layers.
            - gru_gating: Whether the layers use GRU gating.
            - gru_bias: Bias handed to the GRU gating units.
            - use_embedding_layer: Whether the input is first passed through an ``fc_block`` embedding.
        """
        super(GTrXL, self).__init__()
        # Report the value that is actually being checked (embedding_dim, not input_dim).
        assert embedding_dim % 2 == 0, 'embedding_dim={} should be even'.format(embedding_dim)
        self.head_num = head_num
        self.head_dim = head_dim
        self.layer_num = layer_num
        self.embedding_dim = embedding_dim
        # Flatten a multi-dimensional input shape into a plain Python int.
        if isinstance(input_dim, (list, tuple)):
            input_dim = int(np.prod(input_dim))
        # Initialize embedding layer.
        self.use_embedding_layer = use_embedding_layer
        if self.use_embedding_layer:
            self.embedding = fc_block(input_dim, embedding_dim, activation=activation)
        # Initialize activate function.
        self.activation = activation
        # Initialize position embedding.
        self.pos_embedding = PositionalEmbedding(embedding_dim)
        # Memory to save hidden states of past segments.
        # It will be initialized in the forward method to get its size dynamically.
        self.memory = None
        self.memory_len = memory_len
        # Initialize GTrXL layers. Every layer maps embedding_dim -> embedding_dim and uses
        # embedding_dim as the hidden width of its feed-forward network.
        self.dropout = nn.Dropout(dropout_ratio) if dropout_ratio > 0 else nn.Identity()
        layers = [
            GatedTransformerXLLayer(
                embedding_dim, head_dim, embedding_dim, head_num, mlp_num, self.dropout, self.activation,
                gru_gating, gru_bias
            ) for _ in range(layer_num)
        ]
        self.layers = nn.Sequential(*layers)
        # u and v are the parameters to compute global content bias and global positional bias.
        self.u, self.v = (
            torch.nn.Parameter(torch.zeros(self.head_num, self.head_dim)),
            torch.nn.Parameter(torch.zeros(self.head_num, self.head_dim)),
        )
        # Cache of attention masks keyed by current sequence length,
        # so a new one is not created on every forward call.
        self.att_mask = {}
        # Cache of positional embeddings keyed by current sequence length,
        # so a new one is not created on every forward call.
        self.pos_embedding_dict = {}

    # delimiter
    def reset_memory(self, batch_size: Optional[int] = None, state: Optional[torch.Tensor] = None):
        """
        **Overview**:
            Reset the memory of GTrXL, which is called at the beginning of each episode.
            Memory is used to save hidden states of past segments.
        **Arguments**:
            - batch_size: If given, the new memory is created for this batch size (``state`` is then ignored).
            - state: If given (and ``batch_size`` is None), it is loaded into the fresh memory via ``Memory.init``.
        """
        if batch_size is not None:
            # Build the memory once with the requested batch size; no throw-away default instance.
            self.memory = Memory(self.memory_len, batch_size, self.embedding_dim, self.layer_num)
            return
        self.memory = Memory(memory_len=self.memory_len, layer_num=self.layer_num, embedding_dim=self.embedding_dim)
        # If state is not None, add state into the memory.
        if state is not None:
            self.memory.init(state)

    # delimiter
    def get_memory(self):
        """
        **Overview**:
            Access the memory of GTrXL.
        **Returns**:
            - ``None`` if the memory has not been created yet, otherwise the result of ``Memory.get()``.
        """
        return None if self.memory is None else self.memory.get()

    # delimiter
    def forward(self, x: torch.Tensor, batch_first: bool = False, return_mem: bool = True) -> Dict[str, torch.Tensor]:
        """
        **Overview**:
            The forward computation graph of GTrXL.
        **Arguments**:
            - x: Input of shape [sequence_length, batch_size, input_dim], or
              [batch_size, sequence_length, input_dim] when ``batch_first`` is True.
            - batch_first: Whether the first dimension of ``x`` (and of the returned logit) is the batch.
            - return_mem: Whether to include the memory in the returned object.
        **Returns**:
            - A ``treetensor.Object`` with key ``logit`` and, if ``return_mem`` is True, key ``memory``
              holding the memory tensor that was fetched before this call's update.
        """
        # If the first dimension of input x is batch_size,
        # then reshape x from [batch_size ,sequence_length ,input_dim] to [sequence_length, batch_size, input_dim]
        if batch_first:
            x = torch.transpose(x, 1, 0)
        cur_seq, bs = x.shape[:2]
        # Get back memory.
        memory = self.get_memory()
        # Abnormal case: no memory or memory shape mismatch.
        if memory is None:
            self.reset_memory(bs)
        elif memory.shape[-2] != bs or memory.shape[-1] != self.embedding_dim:
            warnings.warn(
                "Memory {} and Input {} dimensions don't match,"
                " this will cause the memory to be initialized to fit your input!".format(
                    list(memory.shape[-2:]), [x.shape[-2]] + [self.embedding_dim]
                )
            )
            self.reset_memory(bs)
        self.memory.to(x.device)
        memory = self.memory.get()
        # Pass through embedding layer.
        if self.use_embedding_layer:
            x = self.dropout(self.embedding(x))
        # Get full sequence length: memory length + current length
        prev_seq = self.memory_len
        full_seq = cur_seq + prev_seq
        # Reuse the attention mask cached for this sequence length if there is one.
        if cur_seq in self.att_mask:
            # The cached mask may live on the device of an earlier call; ``to`` is a no-op otherwise.
            attn_mask = self.att_mask[cur_seq].to(x.device)
        # Otherwise, create a new attention mask.
        else:
            # For example, if cur_seq = 3, full_seq = 7, then the mask is:
            # $$ \begin{matrix} 0 & 0 & 0 & 0 & 0 & 1 & 1 \\ 0 & 0 & 0 & 0 & 0 & 0 & 1 \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 \end{matrix}$$
            # This forces that the hidden state of current token is only associated with previous tokens.
            attn_mask = (
                torch.triu(
                    torch.ones((cur_seq, full_seq)),
                    diagonal=1 + prev_seq,
                ).bool().unsqueeze(-1).to(x.device)
            )
        self.att_mask[cur_seq] = attn_mask
        # Reuse the position encoding cached for this sequence length if there is one.
        if cur_seq in self.pos_embedding_dict:
            # Same device consideration as for the attention mask.
            pos_embedding = self.pos_embedding_dict[cur_seq].to(x.device)
        # Otherwise, create a new position encoding.
        else:
            pos_ips = torch.arange(full_seq - 1, -1, -1.0, dtype=torch.float)  # full_seq
            pos_embedding = self.pos_embedding(pos_ips.to(x.device))
        self.pos_embedding_dict[cur_seq] = pos_embedding
        pos_embedding = self.dropout(pos_embedding)  # full_seq x 1 x embedding_dim
        # hidden_state[0] is the embedded input; hidden_state[i + 1] is the output of layer i.
        hidden_state = [x]
        out = x
        # Calculate results for each GTrXL layer.
        for i in range(self.layer_num):
            layer = self.layers[i]
            out = layer(
                out,
                pos_embedding,
                self.u,
                self.v,
                mask=attn_mask,
                memory=memory[i],
            )
            # Store the pre-dropout output for the memory update.
            hidden_state.append(out.clone())
            out = self.dropout(out)
        # Update the GTrXL memory.
        self.memory.update(hidden_state)
        # If the first dimension of output is required to be batch_size, then reshape the output from
        # [sequence_length, batch_size, input_dim] to [batch_size ,sequence_length ,input_dim].
        if batch_first:
            out = torch.transpose(out, 1, 0)
        # Return memory if needed.
        if return_mem:
            output = treetensor.Object({"logit": out, "memory": memory})
        else:
            output = treetensor.Object({"logit": out})
        return output
# delimiter
def test_gtrxl() -> None:
    """
    **Overview**:
        Test function of GTrXL. The model is run once with a memory created from the batch size
        and once with a memory initialized from a random state tensor.
    """
    # Generate data for testing.
    input_dim, seq_len, bs = 128, 64, 32
    embedding_dim, layer_num, mem_len = 256, 5, 40
    initial_states = (None, torch.rand(layer_num + 1, mem_len, bs, embedding_dim))
    expected_memory_shape = (layer_num + 1, mem_len, bs, embedding_dim)
    # Test GTrXL under different situations.
    for init_state in initial_states:
        model = GTrXL(
            input_dim=input_dim,
            head_dim=2,
            embedding_dim=embedding_dim,
            memory_len=mem_len,
            head_num=2,
            mlp_num=2,
            layer_num=layer_num,
        )
        # Input shape: [sequence_length, batch_size, input_dim]
        obs = torch.rand(seq_len, bs, input_dim, requires_grad=True)
        # Reset the model memory.
        if init_state is not None:
            model.reset_memory(state=init_state)
        else:
            model.reset_memory(batch_size=bs)
        output = model(obs)
        # Check the shape of output.
        assert output['logit'].shape == (seq_len, bs, embedding_dim)
        assert output['memory'].shape == expected_memory_shape
        torch.sum(output['logit']).backward()
        # Check that the gradient reached the input.
        assert isinstance(obs.grad, torch.Tensor)
        # Check that the returned memory equals the state the model was initialized with.
        if init_state is None:
            continue
        memory_out = output['memory']
        assert torch.all(torch.eq(memory_out, init_state))
# Run the GTrXL self-test when this file is executed as a script.
if __name__ == '__main__':
    test_gtrxl()