Commit

just remove PreNorm wrapper from all ViTs, as it is unlikely to change at this point
lucidrains committed Aug 14, 2023
1 parent 4264efd commit ad7d0df
Showing 21 changed files with 137 additions and 232 deletions.
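
For context, the refactor folds the pre-attention and pre-feedforward LayerNorm into each block rather than wrapping blocks in a separate PreNorm module. A minimal sketch of the before/after pattern; the PreNorm definition mirrors the removed code below, while the tail of FeedForward is filled in from the usual vit-pytorch layout and is assumed rather than taken from this diff:

import torch
from torch import nn

# before: normalization lived in an external wrapper
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

# after: each block normalizes its own input
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),           # norm folded into the block
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),  # assumed tail, per the standard FeedForward
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# with dropout = 0 the two forms compute the same thing at initialization
x = torch.randn(2, 16, 64)
ff = FeedForward(64, 128)
wrapped = PreNorm(64, nn.Sequential(*list(ff.net.children())[1:]))
assert torch.allclose(ff(x), wrapped(x), atol = 1e-6)
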
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.4.2',
version = '1.4.3',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',
15 changes: 5 additions & 10 deletions vit_pytorch/ats_vit.py
@@ -110,18 +110,11 @@ def forward(self, attn, value, mask):

# classes

class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -138,6 +131,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., output_num_token
self.heads = heads
self.scale = dim_head ** -0.5

self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

@@ -154,6 +148,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., output_num_token
def forward(self, x, *, mask):
num_tokens = x.shape[1]

x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

@@ -189,8 +184,8 @@ def __init__(self, dim, depth, max_tokens_per_depth, heads, dim_head, mlp_dim, d
self.layers = nn.ModuleList([])
for _, output_num_tokens in zip(range(depth), max_tokens_per_depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, output_num_tokens = output_num_tokens, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))

def forward(self, x):
15 changes: 5 additions & 10 deletions vit_pytorch/cait.py
@@ -44,18 +44,11 @@ def __init__(self, dim, fn, depth):
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) * self.scale

class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -72,6 +65,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
self.heads = heads
self.scale = dim_head ** -0.5

self.norm = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

@@ -89,6 +83,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
def forward(self, x, context = None):
b, n, _, h = *x.shape, self.heads

x = self.norm(x)
context = x if not exists(context) else torch.cat((x, context), dim = 1)

qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
@@ -115,8 +110,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., layer_dro

for ind in range(depth):
self.layers.append(nn.ModuleList([
LayerScale(dim, PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), depth = ind + 1),
LayerScale(dim, PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)), depth = ind + 1)
LayerScale(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), depth = ind + 1),
LayerScale(dim, FeedForward(dim, mlp_dim, dropout = dropout), depth = ind + 1)
]))
def forward(self, x, context = None):
layers = dropout_layers(self.layers, dropout = self.layer_dropout)
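
In cait.py, LayerScale now wraps Attention and FeedForward directly, each of which normalizes its own input. For reference, a sketch of the LayerScale wrapper whose forward appears in the hunk above; the depth-dependent init values are assumed from the CaiT paper and the usual vit-pytorch implementation, not from this diff:

import torch
from torch import nn

class LayerScale(nn.Module):
    def __init__(self, dim, fn, depth):
        super().__init__()
        # init scale shrinks with depth (0.1 / 1e-5 / 1e-6), per the CaiT recipe (assumed)
        if depth <= 18:
            init_eps = 0.1
        elif depth > 18 and depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6
        self.scale = nn.Parameter(torch.full((1, 1, dim), init_eps))
        self.fn = fn

    def forward(self, x, **kwargs):
        # rescale the wrapped block's output by a learned per-channel factor
        return self.fn(x, **kwargs) * self.scale
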
21 changes: 7 additions & 14 deletions vit_pytorch/cross_vit.py
@@ -13,22 +13,13 @@ def exists(val):
def default(val, d):
return val if exists(val) else d

# pre-layernorm

class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)

# feedforward

class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -47,6 +38,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
self.heads = heads
self.scale = dim_head ** -0.5

self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

@@ -60,6 +52,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):

def forward(self, x, context = None, kv_include_self = False):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)
context = default(context, x)

if kv_include_self:
@@ -86,8 +79,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
self.norm = nn.LayerNorm(dim)
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))

def forward(self, x):
@@ -121,8 +114,8 @@ def __init__(self, sm_dim, lg_dim, depth, heads, dim_head, dropout):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
ProjectInOut(sm_dim, lg_dim, PreNorm(lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout))),
ProjectInOut(lg_dim, sm_dim, PreNorm(sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout)))
ProjectInOut(sm_dim, lg_dim, Attention(lg_dim, heads = heads, dim_head = dim_head, dropout = dropout)),
ProjectInOut(lg_dim, sm_dim, Attention(sm_dim, heads = heads, dim_head = dim_head, dropout = dropout))
]))

def forward(self, sm_tokens, lg_tokens):
17 changes: 6 additions & 11 deletions vit_pytorch/cvt.py
@@ -34,19 +34,11 @@ def forward(self, x):
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)

class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
LayerNorm(dim),
nn.Conv2d(dim, dim * mult, 1),
nn.GELU(),
nn.Dropout(dropout),
@@ -75,6 +67,7 @@ def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, d
self.heads = heads
self.scale = dim_head ** -0.5

self.norm = LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)

@@ -89,6 +82,8 @@ def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, d
def forward(self, x):
shape = x.shape
b, n, _, y, h = *shape, self.heads

x = self.norm(x)
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))

@@ -107,8 +102,8 @@ def __init__(self, dim, proj_kernel, kv_proj_stride, depth, heads, dim_head = 64
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout)),
PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_mult, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
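
cvt.py keeps its own channel-first LayerNorm (only the tail of its forward shows in the hunk above) because the tokens stay in (b, c, h, w) layout. A sketch of the full class, reconstructed on the assumption that it matches the standard vit-pytorch definition:

import torch
from torch import nn

class LayerNorm(nn.Module):  # layernorm over the channel dim of a (b, c, h, w) tensor
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))   # per-channel scale
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))  # per-channel shift

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

# usage sketch: norm = LayerNorm(64); y = norm(torch.randn(2, 64, 32, 32))
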
27 changes: 8 additions & 19 deletions vit_pytorch/deepvit.py
@@ -5,25 +5,11 @@
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -40,6 +26,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
self.heads = heads
self.scale = dim_head ** -0.5

self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

self.dropout = nn.Dropout(dropout)
@@ -59,6 +46,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):

def forward(self, x):
b, n, _, h = *x.shape, self.heads
x = self.norm(x)

qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

@@ -86,13 +75,13 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x)
x = ff(x)
x = attn(x) + x
x = ff(x) + x
return x

class DeepViT(nn.Module):
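
deepvit.py also drops its Residual wrapper: the skip connection is now written inline in the Transformer forward (x = attn(x) + x). A small sketch of the equivalence, using a plain Linear as a hypothetical stand-in for Attention or FeedForward:

import torch
from torch import nn

class Residual(nn.Module):  # the wrapper removed in this commit
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

block = nn.Linear(64, 64)        # stand-in block; in the real code this is Attention or FeedForward
x = torch.randn(2, 16, 64)

y_wrapped = Residual(block)(x)   # before: residual handled by the wrapper
y_inline = block(x) + x          # after: residual written inline in forward
assert torch.allclose(y_wrapped, y_inline)
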
18 changes: 6 additions & 12 deletions vit_pytorch/local_vit.py
@@ -26,16 +26,6 @@ def forward(self, x, **kwargs):
x = self.fn(x, **kwargs)
return torch.cat((cls_token, x), dim = 1)

# prenorm

class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)

# feed forward related classes

class DepthWiseConv2d(nn.Module):
@@ -52,6 +42,7 @@ class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Conv2d(dim, hidden_dim, 1),
nn.Hardswish(),
DepthWiseConv2d(hidden_dim, hidden_dim, 3, padding = 1),
@@ -77,6 +68,7 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
self.heads = heads
self.scale = dim_head ** -0.5

self.norm = nn.LayerNorm(dim)
self.attend = nn.Softmax(dim = -1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
@@ -88,6 +80,8 @@ def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):

def forward(self, x):
b, n, _, h = *x.shape, self.heads

x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

@@ -106,8 +100,8 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
ExcludeCLS(Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))))
Residual(Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
ExcludeCLS(Residual(FeedForward(dim, mlp_dim, dropout = dropout)))
]))
def forward(self, x):
for attn, ff in self.layers:
17 changes: 10 additions & 7 deletions vit_pytorch/max_vit.py
@@ -19,20 +19,20 @@ def cast_tuple(val, length = 1):

# helper classes

class PreNormResidual(nn.Module):
class Residual(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn

def forward(self, x):
return self.fn(self.norm(x)) + x
return self.fn(x) + x

class FeedForward(nn.Module):
def __init__(self, dim, mult = 4, dropout = 0.):
super().__init__()
inner_dim = int(dim * mult)
self.net = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim),
nn.GELU(),
nn.Dropout(dropout),
@@ -132,6 +132,7 @@ def __init__(
self.heads = dim // dim_head
self.scale = dim_head ** -0.5

self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, dim * 3, bias = False)

self.attend = nn.Sequential(
@@ -160,6 +161,8 @@ def __init__(
def forward(self, x):
batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads

x = self.norm(x)

# flatten

x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')
@@ -259,13 +262,13 @@ def __init__(
shrinkage_rate = mbconv_shrinkage_rate
),
Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w), # block-like attention
PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
Residual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
Residual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),

Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w), # grid-like attention
PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
Residual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
Residual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
)

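
A quick smoke test that one of the refactored models still runs end to end; the constructor arguments follow the repository README and are illustrative, not prescriptive:

import torch
from vit_pytorch.deepvit import DeepViT

v = DeepViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
preds = v(img)  # expected shape: (1, 1000)
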