Global Context ViT #95

Draft
wants to merge 14 commits into base: main
2 changes: 1 addition & 1 deletion vformer/attention/__init__.py
@@ -4,4 +4,4 @@
from .memory_efficient import MemoryEfficientAttention
from .spatial import SpatialAttention
from .vanilla import VanillaSelfAttention
from .window import WindowAttention
from .window import WindowAttention, WindowAttentionGlobal
114 changes: 114 additions & 0 deletions vformer/attention/window.py
@@ -116,3 +116,117 @@ def forward(self, x, mask=None):
x = self.to_out_2(x)

return x


@ATTENTION_REGISTRY.register()
class WindowAttentionGlobal(nn.Module):
"""
Parameters
----------
dim: int
Number of input channels.
window_size : int or tuple[int]
The height and width of the window.
num_heads: int
Number of attention heads.
qkv_bias :bool, default is True
If True, add a learnable bias to query, key, value.
qk_scale: float, optional
Override default qk scale of head_dim ** -0.5 if set
attn_dropout: float, optional
Dropout rate
proj_dropout: float, optional
Dropout rate
"""

def __init__(
self,
dim,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_dropout=0.0,
proj_dropout=0.0,
):
super(WindowAttentionGlobal, self).__init__()

self.dim = dim
self.window_size = pair(window_size)
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = qk_scale or self.head_dim**-0.5
self.qkv_bias = qkv_bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros(
(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads
)
)
relative_position_index = get_relative_position_bias_index(self.window_size)
self.register_buffer("relative_position_index", relative_position_index)

self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias)
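# note: despite the name, the projection above produces only keys and values (2 * dim); the query comes from q_global in forward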
self.to_out_1 = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(attn_dropout))
self.to_out_2 = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(proj_dropout))
nn.init.trunc_normal_(self.relative_position_bias_table, std=0.2)

def forward(self, x, q_global, mask=None):
"""
Parameters
----------
x: torch.Tensor
input Tensor
mask: torch.Tensor
Attention mask used for shifted window attention, if None, window attention will be used,
else attention mask will be taken into consideration.
for better understanding you may refer `this <https://github.com/microsoft/Swin-Transformer/issues/38>`
Returns
----------
torch.Tensor
Returns output tensor by applying Window-Attention or Shifted-Window-Attention on input tensor
"""

B_, N, C = x.shape
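# B_ = B * num_windows since x holds window-partitioned tokens; q_global has one entry per image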
B = q_global.shape[0]

kv = (
self.qkv(x)
.reshape(
B_,
N,
2,
self.num_heads,
C // self.num_heads,
)
.permute(2, 0, 3, 1, 4)
)
k, v = kv[0], kv[1]
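# the query is not derived from x: the shared global query is tiled along the batch dimension so there is one copy per window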
q_global = q_global.repeat(B_ // B, 1, 1, 1)
q = q_global.reshape(B_, self.num_heads, N, C // self.num_heads)
q = q * self.scale
attn = q @ k.transpose(-2, -1)

relative_position_bias = (
self.relative_position_bias_table[self.relative_position_index.view(-1)]
.view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1,
)
.permute(2, 0, 1)
.contiguous()
)
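# add the learned relative position bias per head, indexed by pairwise token offsets within the window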
attn = attn + relative_position_bias.unsqueeze(0)

if mask is not None:
nW = mask.shape[0]
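# broadcast the per-window mask over the batch and attention heads, then flatten back before the softmax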
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
1
).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)

attn = self.to_out_1(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.to_out_2(x)

return x
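
As a reviewer aid, a minimal usage sketch of the new module (not part of the diff). It assumes this PR's branch of vformer is installed; the dim, window_size, and num_heads values are arbitrary, and the q_global shape is simply one that is consistent with the reshape in forward, so the GC-ViT encoder in this PR may pass it differently.

import torch

from vformer.attention import WindowAttentionGlobal

dim, window_size, num_heads = 96, 7, 4
attn = WindowAttentionGlobal(dim=dim, window_size=window_size, num_heads=num_heads)

B, num_windows = 2, 16  # image batch size and windows per image
N = window_size * window_size  # tokens per window
x = torch.randn(B * num_windows, N, dim)  # window-partitioned local tokens
q_global = torch.randn(B, num_heads, N, dim // num_heads)  # shared global query, one per image

out = attn(x, q_global)  # -> (B * num_windows, N, dim)

Unlike WindowAttention, no query is projected from x; the module learns only key and value projections and attends with the externally supplied global query.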
1 change: 1 addition & 0 deletions vformer/encoder/__init__.py
@@ -2,6 +2,7 @@
from .convvt import ConvVTStage
from .cross import CrossEncoder
from .embedding import *
from .gc import GCViTBlock, GCViTLayer
from .nn import FeedForward
from .perceiver_io import PerceiverIOEncoder
from .pyramid import PVTEncoder