Expose attention matrices for ViT architectures (#74)
* Expose attention matrices for ViT architectures

* Changed parameter names.

Co-authored-by: Laurent Navarro <[email protected]>
Co-authored-by: Martins Bruveris <[email protected]>
3 people authored Oct 24, 2022
1 parent d8d5816 commit 0fecb65
Showing 1 changed file with 23 additions and 12 deletions.
tfimm/architectures/vit.py (23 additions, 12 deletions)
@@ -6,6 +6,7 @@
 Copyright 2021 Martins Bruveris
 """
+from collections import OrderedDict
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

@@ -145,7 +146,9 @@ def __init__(
         self.proj = tf.keras.layers.Dense(units=embed_dim, name="proj")
         self.proj_drop = tf.keras.layers.Dropout(rate=drop_rate)
 
-    def call(self, x, training=False):
+    def call(self, x, training=False, return_features=False):
+        features = OrderedDict()
+
         # B (batch size), N (sequence length), D (embedding dimension),
         # H (number of heads)
         batch_size, seq_length = tf.unstack(tf.shape(x)[:2])
@@ -157,14 +160,15 @@ def call(self, x, training=False):
         attn = self.scale * tf.linalg.matmul(q, k, transpose_b=True)  # (B, H, N, N)
         attn = tf.nn.softmax(attn, axis=-1)  # (B, H, N, N)
         attn = self.attn_drop(attn, training=training)
+        features["attn"] = attn
 
         x = tf.linalg.matmul(attn, v)  # (B, H, N, D/H)
         x = tf.transpose(x, (0, 2, 1, 3))  # (B, N, H, D/H)
         x = tf.reshape(x, (batch_size, seq_length, -1))  # (B, N, D)
 
         x = self.proj(x)
         x = self.proj_drop(x, training=training)
-        return x
+        return (x, features) if return_features else x
 
 
 class ViTBlock(tf.keras.layers.Layer):
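The softmaxed attention matrix is now stored per layer with shape (B, H, N, N). A minimal sketch (not part of this commit) of how that tensor is typically reduced for visualisation, assuming the standard ViT token layout with the class token at index 0:

import tensorflow as tf

def cls_attention_map(attn: tf.Tensor, grid_size) -> tf.Tensor:
    # attn: (B, H, N, N) as exposed above; N = 1 + h * w patch tokens.
    attn = tf.reduce_mean(attn, axis=1)   # average over heads -> (B, N, N)
    cls_to_patches = attn[:, 0, 1:]       # class-token row, patch columns -> (B, N - 1)
    h, w = grid_size
    return tf.reshape(cls_to_patches, (-1, h, w))

Here grid_size is the (height, width) of the patch grid, e.g. (14, 14) for a 224-pixel input with 16-pixel patches.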
@@ -212,10 +216,14 @@ def __init__(
             name="mlp",
         )
 
-    def call(self, x, training=False):
+    def call(self, x, training=False, return_features=False):
+        features = OrderedDict()
         shortcut = x
         x = self.norm1(x, training=training)
-        x = self.attn(x, training=training)
+        x = self.attn(x, training=training, return_features=return_features)
+        if return_features:
+            x, mha_features = x
+            features["attn"] = mha_features["attn"]
         x = self.drop_path(x, training=training)
         x = x + shortcut
 
@@ -224,7 +232,7 @@ def call(self, x, training=False):
         x = self.mlp(x, training=training)
         x = self.drop_path(x, training=training)
         x = x + shortcut
-        return x
+        return (x, features) if return_features else x
 
 
 class HybridEmbeddings(tf.keras.layers.Layer):
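Both layers use the same `return (x, features) if return_features else x` idiom, so existing callers that never pass the flag keep receiving a plain tensor. A hypothetical helper (not part of tfimm) that normalises the two return shapes for downstream code could look like this:

def unpack_output(result, return_features: bool):
    # With the flag set the layer returns an (x, features) tuple,
    # otherwise just the tensor; return a uniform pair either way.
    if return_features:
        x, features = result
    else:
        x, features = result, {}
    return x, features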
@@ -397,11 +405,11 @@ def dummy_inputs(self) -> tf.Tensor:
 
     @property
     def feature_names(self) -> List[str]:
-        return (
-            ["patch_embedding"]
-            + [f"block_{j}" for j in range(self.cfg.nb_blocks)]
-            + ["features_all", "features", "logits"]
-        )
+        """
+        Names of features, returned when calling ``call`` with ``return_features=True``.
+        """
+        _, features = self(self.dummy_inputs, return_features=True)
+        return list(features.keys())
 
     def transform_pos_embed(self, target_cfg: ViTConfig):
         return interpolate_pos_embeddings(
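Because `feature_names` is now derived from a dummy forward pass rather than a hard-coded list, it automatically picks up the new per-block attention keys. A hedged sketch of inspecting it (the model name is only an example):

import tfimm

model = tfimm.create_model("vit_tiny_patch16_224")  # any ViT variant
print(model.feature_names)
# Expected to contain entries such as "patch_embedding", "block_0/attn",
# "block_0", ..., "features_all", "features" and "logits".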
@@ -412,7 +420,7 @@ def transform_pos_embed(self, target_cfg: ViTConfig):
         )
 
     def forward_features(self, x, training=False, return_features=False):
-        features = {}
+        features = OrderedDict()
         batch_size = tf.shape(x)[0]
 
         x, grid_size = self.patch_embed(x, return_shape=True)
@@ -436,7 +444,10 @@ def forward_features(self, x, training=False, return_features=False):
         features["patch_embedding"] = x
 
         for j, block in enumerate(self.blocks):
-            x = block(x, training=training)
+            x = block(x, training=training, return_features=return_features)
+            if return_features:
+                x, block_features = x
+                features[f"block_{j}/attn"] = block_features["attn"]
             features[f"block_{j}"] = x
         x = self.norm(x, training=training)
         features["features_all"] = x
