
Commit

add all dropout
lucidrains committed Jul 2, 2020
1 parent 66daca3 commit da19be5
Showing 2 changed files with 37 additions and 12 deletions.
47 changes: 36 additions & 11 deletions compressive_transformer_pytorch/compressive_transformer_pytorch.py
@@ -99,21 +99,39 @@ def forward(self, mem):
 
 # feedforward
 
+class GELU_(nn.Module):
+    def forward(self, x):
+        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+
+GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_
+
 class FeedForward(nn.Module):
-    def __init__(self, dim, mult = 4):
+    def __init__(self, dim, mult = 4, dropout = 0., activation = None, glu = False):
         super().__init__()
-        self.net = nn.Sequential(
-            nn.Linear(dim, dim * mult),
-            nn.LeakyReLU(inplace = True),
-            nn.Linear(dim * mult, dim)
-        )
+        activation = default(activation, GELU)
+
+        self.glu = glu
+        self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1))
+        self.act = activation()
+        self.dropout = nn.Dropout(dropout)
+        self.w2 = nn.Linear(dim * mult, dim)
 
     def forward(self, x, **kwargs):
-        return self.net(x)
+        if not self.glu:
+            x = self.w1(x)
+            x = self.act(x)
+        else:
+            x, v = self.w1(x).chunk(2, dim=-1)
+            x = self.act(x) * v
+
+        x = self.dropout(x)
+        x = self.w2(x)
+        return x
 
 # attention.
 
 class SelfAttention(nn.Module):
-    def __init__(self, dim, seq_len, mem_len, cmem_len, cmem_ratio = 4, heads = 8):
+    def __init__(self, dim, seq_len, mem_len, cmem_len, cmem_ratio = 4, heads = 8, attn_dropout = 0., dropout = 0.):
         super().__init__()
         self.heads = heads
         self.dim_head = dim // heads
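
The rewritten FeedForward above is a GLU-variant feedforward with dropout on the hidden activations. A minimal standalone sketch of the same pattern, assembled so it runs on its own (the default helper is assumed to behave like the repo's; it is not part of this diff):

import math
import torch
import torch.nn as nn

def default(val, d):
    # assumed repo-style helper: fall back to d when val is None
    return val if val is not None else d

class GELU_(nn.Module):
    # tanh approximation of GELU, used when the installed torch lacks nn.GELU
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0., activation = None, glu = False):
        super().__init__()
        activation = default(activation, GELU)
        self.glu = glu
        # with glu = True, w1 projects to twice the hidden width so the result
        # can be split into a value half and a gating half
        self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1))
        self.act = activation()
        self.dropout = nn.Dropout(dropout)
        self.w2 = nn.Linear(dim * mult, dim)

    def forward(self, x, **kwargs):
        if not self.glu:
            x = self.act(self.w1(x))
        else:
            x, v = self.w1(x).chunk(2, dim = -1)
            x = self.act(x) * v
        x = self.dropout(x)          # new: dropout on the hidden activations
        return self.w2(x)

ff = FeedForward(512, dropout = 0.1, glu = True)
out = ff(torch.randn(2, 16, 512))    # -> torch.Size([2, 16, 512])
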
@@ -129,6 +147,9 @@ def __init__(self, dim, seq_len, mem_len, cmem_len, cmem_ratio = 4, heads = 8):
         self.to_kv = nn.Linear(dim, dim * 2, bias = False)
         self.to_out = nn.Linear(dim, dim)
 
+        self.attn_dropout = nn.Dropout(attn_dropout)
+        self.dropout = nn.Dropout(dropout)
+
     def forward(self, x, memories = None, pos_emb = None, **kwargs):
         b, t, e, h, dim_h = *x.shape, self.heads, self.dim_head
 
@@ -162,10 +183,12 @@ def forward(self, x, memories = None, pos_emb = None, **kwargs):
         dots.masked_fill_(mask[None, None, ...], float('-inf'))
 
         attn = dots.softmax(dim=-1)
+        attn = self.attn_dropout(attn)
 
         out = torch.einsum('bhij,bhjd->bhid', attn, v)
         out = out.transpose(1, 2).reshape(b, t, -1)
         logits = self.to_out(out)
+        logits = self.dropout(logits)
 
         new_mem = mem
         new_cmem = cmem
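
The two new dropouts act at different points of the attention step: attn_dropout zeroes entries of the post-softmax attention matrix, while dropout is applied to the projected output. A small sketch of that ordering (the function name and shapes here are illustrative, not taken from the diff):

import torch
import torch.nn as nn

def attend(q, k, v, to_out, attn_drop, out_drop):
    # q, k, v: (batch, heads, seq, dim_head)
    b, h, t, d = q.shape
    dots = torch.einsum('bhid,bhjd->bhij', q, k) * (d ** -0.5)
    attn = dots.softmax(dim = -1)
    attn = attn_drop(attn)                      # attn_dropout: drops attention weights
    out = torch.einsum('bhij,bhjd->bhid', attn, v)
    out = out.transpose(1, 2).reshape(b, t, h * d)
    return out_drop(to_out(out))                # dropout: drops the projected output

b, h, t, d = 2, 8, 16, 64
q = k = v = torch.randn(b, h, t, d)
y = attend(q, k, v, nn.Linear(h * d, h * d), nn.Dropout(0.1), nn.Dropout(0.1))
print(y.shape)   # torch.Size([2, 16, 512])
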
@@ -202,12 +225,13 @@ def forward(self, x, memories = None, pos_emb = None, **kwargs):
                 full_attn(q, cmem_k, cmem_v)
             )
 
+
         return logits, Memory(mem = new_mem, cmem = new_cmem), aux_loss
 
 # transformer
 
 class CompressiveTransformer(nn.Module):
-    def __init__(self, num_tokens, dim, seq_len, depth, mem_len = None, cmem_len = None, cmem_ratio = 4, heads = 8, gru_gated_residual = True):
+    def __init__(self, num_tokens, dim, seq_len, depth, mem_len = None, cmem_len = None, cmem_ratio = 4, heads = 8, gru_gated_residual = True, attn_dropout = 0., ff_dropout = 0., attn_layer_dropout = 0.):
         super().__init__()
         mem_len = default(mem_len, seq_len)
         cmem_len = default(cmem_len, mem_len // cmem_ratio)
@@ -222,8 +246,8 @@ def __init__(self, num_tokens, dim, seq_len, depth, mem_len = None, cmem_len = N
 
         wrapper = partial(GRUGating, dim) if gru_gated_residual else Residual
 
-        self.attn_layers = nn.ModuleList([wrapper(PreNorm(dim, SelfAttention(dim, seq_len, mem_len, cmem_len, cmem_ratio, heads))) for _ in range(depth)])
-        self.ff_layers = nn.ModuleList([wrapper(PreNorm(dim, FeedForward(dim))) for _ in range(depth)])
+        self.attn_layers = nn.ModuleList([wrapper(PreNorm(dim, SelfAttention(dim, seq_len, mem_len, cmem_len, cmem_ratio, heads, dropout = attn_layer_dropout, attn_dropout = attn_dropout))) for _ in range(depth)])
+        self.ff_layers = nn.ModuleList([wrapper(PreNorm(dim, FeedForward(dim, dropout = ff_dropout))) for _ in range(depth)])
 
     def forward(self, x, memories = None):
         x = self.token_emb(x)
@@ -246,6 +270,7 @@ def forward(self, x, memories = None):
         for attn, ff, m, c in zip(self.attn_layers, self.ff_layers, mem, cmem):
             x, (mem_out, cmem_out), layer_aux_loss = attn(x, memories = (m, c), pos_emb = pos_emb)
             x, = ff(x)
+
             next_mem.append(mem_out)
             next_cmem.append(cmem_out)
             aux_loss = aux_loss + layer_aux_loss
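Taken together, the three new constructor arguments expose all of the above at the model level. A usage sketch against the signature shown in this diff (the argument values are illustrative, and the return value is left unpacked since its exact shape is not part of this diff):

import torch
from compressive_transformer_pytorch import CompressiveTransformer

model = CompressiveTransformer(
    num_tokens = 20000,
    dim = 512,
    seq_len = 1024,
    depth = 6,
    attn_dropout = 0.1,        # dropout on the attention weights
    attn_layer_dropout = 0.1,  # dropout after the attention output projection
    ff_dropout = 0.1           # dropout inside the feedforward blocks
)

x = torch.randint(0, 20000, (1, 1024))
out = model(x)   # memories default to None, per the forward signature above
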
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'compressive_transformer_pytorch',
   packages = find_packages(),
-  version = '0.0.2',
+  version = '0.0.3',
   license='MIT',
   description = 'Implementation of Compressive Transformer in Pytorch',
   author = 'Phil Wang',
