attention_nystrom.py
import math

import torch
import torch.nn as nn


class NystromAttention(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.head_dim = config["head_dim"]
        self.num_head = config["num_head"]
        self.num_landmarks = config["num_landmarks"]
        self.seq_len = config["seq_len"]

        if "inv_coeff_init_option" in config:
            self.init_option = config["inv_coeff_init_option"]
        else:
            self.init_option = "original"

        # Optional depthwise convolution over V, added to the attention output as a skip connection.
        self.use_conv = "conv_kernel_size" in config
        if self.use_conv:
            self.conv = nn.Conv2d(
                in_channels = self.num_head, out_channels = self.num_head,
                kernel_size = (config["conv_kernel_size"], 1), padding = (config["conv_kernel_size"] // 2, 0),
                bias = False,
                groups = self.num_head)
    def forward(self, Q, K, V, mask):
        # Scale Q and K by head_dim^{-1/4} each (so Q K^T carries the usual 1/sqrt(head_dim)) and zero out padded positions.
        Q = Q * mask[:, None, :, None] / math.sqrt(math.sqrt(self.head_dim))
        K = K * mask[:, None, :, None] / math.sqrt(math.sqrt(self.head_dim))

        if self.num_landmarks == self.seq_len:
            # One landmark per position: fall back to exact softmax attention.
            attn = torch.nn.functional.softmax(torch.matmul(Q, K.transpose(-1, -2)) - 1e9 * (1 - mask[:, None, None, :]), dim = -1)
            X = torch.matmul(attn, V)
        else:
            # Landmarks are segment means of Q and K over seq_len // num_landmarks consecutive positions.
            Q_landmarks = Q.reshape(-1, self.num_head, self.num_landmarks, self.seq_len // self.num_landmarks, self.head_dim).mean(dim = -2)
            K_landmarks = K.reshape(-1, self.num_head, self.num_landmarks, self.seq_len // self.num_landmarks, self.head_dim).mean(dim = -2)

            # Nystrom approximation: softmax(Q K^T) ~ kernel_1 @ pinv(kernel_2) @ kernel_3.
            kernel_1 = torch.nn.functional.softmax(torch.matmul(Q, K_landmarks.transpose(-1, -2)), dim = -1)
            kernel_2 = torch.nn.functional.softmax(torch.matmul(Q_landmarks, K_landmarks.transpose(-1, -2)), dim = -1)
            kernel_3 = torch.nn.functional.softmax(torch.matmul(Q_landmarks, K.transpose(-1, -2)) - 1e9 * (1 - mask[:, None, None, :]), dim = -1)
            X = torch.matmul(torch.matmul(kernel_1, self.iterative_inv(kernel_2)), torch.matmul(kernel_3, V))

        if self.use_conv:
            X += self.conv(V * mask[:, None, :, None])

        return X
    def iterative_inv(self, mat, n_iter = 6):
        # Newton-Schulz-style iteration for an approximate Moore-Penrose pseudoinverse of mat.
        I = torch.eye(mat.size(-1), device = mat.device)
        K = mat

        # The entries of K are positive and ||K||_{\infty} = 1 due to softmax.
        if self.init_option == "original":
            # The original initialization is more conservative when computing the coefficient of Z_0.
            V = 1 / torch.max(torch.sum(K, dim = -2)) * K.transpose(-1, -2)
        else:
            # Exact coefficient 1 / ||K||_1 for the initialization of Z_0, leading to faster convergence.
            V = 1 / torch.max(torch.sum(K, dim = -2), dim = -1).values[:, :, None, None] * K.transpose(-1, -2)

        for _ in range(n_iter):
            KV = torch.matmul(K, V)
            V = torch.matmul(0.25 * V, 13 * I - torch.matmul(KV, 15 * I - torch.matmul(KV, 7 * I - KV)))
        return V
    def extra_repr(self):
        return f'num_landmarks={self.num_landmarks}, seq_len={self.seq_len}'
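

# Usage sketch: a minimal example of instantiating and calling NystromAttention.
# The config values, batch size, and tensor shapes below are illustrative
# assumptions, not values prescribed by this file; the config keys match those
# read in __init__ above.
if __name__ == "__main__":
    config = {
        "head_dim": 64,
        "num_head": 8,
        "num_landmarks": 32,
        "seq_len": 512,
        # "conv_kernel_size": 33,  # optional: enables the depthwise skip convolution
    }
    attn = NystromAttention(config)

    batch_size = 2
    shape = (batch_size, config["num_head"], config["seq_len"], config["head_dim"])
    Q = torch.randn(*shape)
    K = torch.randn(*shape)
    V = torch.randn(*shape)
    mask = torch.ones(batch_size, config["seq_len"])  # 1 = keep token, 0 = padding

    X = attn(Q, K, V, mask)
    print(X.shape)  # expected: (2, 8, 512, 64)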