Update GC encoder
Amapocho committed Aug 22, 2022
1 parent 0f95103 commit 3006bf8
Showing 1 changed file with 161 additions and 0 deletions.
161 changes: 161 additions & 0 deletions vformer/encoder/gc.py
@@ -73,8 +73,85 @@ def forward(self, x):
return x


class GlobalQueryGen(nn.Module):
def __init__(self, dim, input_resolution, window_size, num_heads):
"""
Args:
dim: feature size dimension.
input_resolution: input image resolution.
window_size: window size.
num_heads: number of heads.
For instance, repeating log(56/7) = 3 blocks, with input window dimension 56 and output window dimension 7 at
down-sampling ratio 2. Please check Fig.5 of GC ViT paper for details.
"""

super().__init__()
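        # stack log2(input_resolution / window_size) down-sampling blocks;
        # keep_dim=True is used where the spatial size already matches the
        # window size and no further down-sampling is needed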
if input_resolution == 56:
self.to_q_global = nn.Sequential(
FeatExtract(dim, keep_dim=False),
FeatExtract(dim, keep_dim=False),
FeatExtract(dim, keep_dim=False),
)

elif input_resolution == 28:
self.to_q_global = nn.Sequential(
FeatExtract(dim, keep_dim=False),
FeatExtract(dim, keep_dim=False),
)

elif input_resolution == 14:

if window_size == 14:
self.to_q_global = nn.Sequential(FeatExtract(dim, keep_dim=True))

elif window_size == 7:
self.to_q_global = nn.Sequential(FeatExtract(dim, keep_dim=False))

elif input_resolution == 7:
self.to_q_global = nn.Sequential(FeatExtract(dim, keep_dim=True))

self.resolution = input_resolution
self.num_heads = num_heads
self.N = window_size * window_size
self.dim_head = dim // self.num_heads

    def forward(self, x):
        # x: (B, C, H, W) feature map; extract the global query features
        x = self.to_q_global(x)
        # to channels-last: (B, H', W', C) with H' * W' == self.N
        x = x.permute(0, 2, 3, 1)
        B = x.shape[0]
        # reshape and permute to (B, 1, num_heads, N, dim_head) so the global
        # query broadcasts across all local windows
        x = x.reshape(B, 1, self.N, self.num_heads, self.dim_head).permute(
            0, 1, 3, 2, 4
        )
        return x
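# Minimal shape sketch (illustrative values, not part of this module; assumes
# FeatExtract halves the spatial size when keep_dim=False):
#   q_gen = GlobalQueryGen(dim=96, input_resolution=56, window_size=7, num_heads=4)
#   q_global = q_gen(torch.randn(2, 96, 56, 56))  # -> (2, 1, 4, 49, 24)
# i.e. the 56x56 map is reduced 56 -> 28 -> 14 -> 7 and reshaped into one
# global query per head, shared by all local windows.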


@ENCODER_REGISTRY.register()
class GCViTBlock(nn.Module):
"""
GCViT block based on: "Hatamizadeh et al.,
Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
Parameters
----------
dim: feature size dimension.
input_resolution: input image resolution.
    num_heads: number of attention heads.
window_size: window size.
mlp_ratio: MLP ratio.
qkv_bias: bool argument for query, key, value learnable bias.
    qk_scale: optional override for the default query-key scale factor.
drop: dropout rate.
attn_drop: attention dropout rate.
drop_path_rate: drop path rate.
drop_path_mode: drop path mode.
act_layer: activation function.
attention: attention block type.
norm_layer: normalization layer.
layer_scale: layer scaling coefficient.
"""

def __init__(
self,
dim,
@@ -93,6 +170,7 @@ def __init__(
norm_layer=nn.LayerNorm,
layer_scale=None,
):

super().__init__()
self.window_size = window_size
self.norm1 = norm_layer(dim)
@@ -234,3 +312,86 @@ def forward(self, x):
if self.downsample is None:
return x
return self.downsample(x)


class GCViTLayer(nn.Module):
"""
GCViT layer based on: "Hatamizadeh et al.,
Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
Parameters
----------
dim: feature size dimension.
depth: number of layers in each stage.
input_resolution: input image resolution.
window_size: window size in each stage.
downsample: bool argument for down-sampling.
mlp_ratio: MLP ratio.
num_heads: number of heads in each stage.
qkv_bias: bool argument for query, key, value learnable bias.
    qk_scale: optional override for the default query-key scale factor.
drop: dropout rate.
attn_drop: attention dropout rate.
drop_path_rate: drop path rate.
norm_layer: normalization layer.
layer_scale: layer scaling coefficient.
"""

def __init__(
self,
dim,
depth,
input_resolution,
num_heads,
window_size,
downsample=True,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path_rate=0.0,
norm_layer=nn.LayerNorm,
layer_scale=None,
):

super().__init__()
self.blocks = nn.ModuleList(
[
GCViTBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
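                    # even-indexed blocks use local window attention;
                    # odd-indexed blocks use global-query attention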
attention=WindowAttention
if (i % 2 == 0)
else WindowAttentionGlobal,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path_rate[i]
if isinstance(drop_path_rate, list)
else drop_path_rate,
norm_layer=norm_layer,
layer_scale=layer_scale,
input_resolution=input_resolution,
)
for i in range(depth)
]
)
self.downsample = (
None if not downsample else ReduceSize(dim=dim, norm_layer=norm_layer)
)
self.q_global_gen = GlobalQueryGen(
dim, input_resolution, window_size, num_heads
)

def forward(self, x):
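        # x is channels-last (B, H, W, C); the conv-based query generator
        # expects channels-first, hence the permute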
q_global = self.q_global_gen(x.permute(0, 3, 1, 2))
for blk in self.blocks:
x = blk(x, q_global)
if self.downsample is None:
return x
return self.downsample(x)
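
For reference, a minimal usage sketch of the new GCViTLayer (illustrative values only; it assumes FeatExtract halves the spatial size when keep_dim=False and that ReduceSize halves the resolution while doubling channels, as in the GC ViT design):

    import torch

    # hypothetical stage configuration: 56x56 features, 96 channels
    layer = GCViTLayer(
        dim=96,
        depth=2,
        input_resolution=56,
        num_heads=4,
        window_size=7,
    )
    x = torch.randn(2, 56, 56, 96)  # channels-last input: (B, H, W, C)
    out = layer(x)
    # with downsample=True (the default), out is expected to be (2, 28, 28, 192)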
