
Probing ViTs

TensorFlow 2.8 | Hugging Face

By Aritra Roy Gosthipaty and Sayak Paul (equal contribution)

In this repository, we provide tools to probe into the representations learned by different families of Vision Transformers (supervised pre-training with ImageNet-21k, ImageNet-1k, distillation, self-supervised pre-training):

  • Original ViT [1]
  • DeiT [2]
  • DINO [3]

We hope these tools prove useful to the community. Please follow along with this post on keras.io for easier navigation through the repository.

Updates

Self-attention visualization

[Figure: original image, attention maps, and attention maps overlaid on the image]
output-dino.mp4

Original Video Source

output-dog.mp4

Original Video Source

Supervised salient representations

In the DINO blog post, the authors show a video with the following caption:

The original video is shown on the left. In the middle is a segmentation example generated by a supervised model, and on the right is one generated by DINO.

A screenshot of the video is as follows:

[Image: screenshot of the video]

We extracted attention maps from a supervised pre-trained model and found them to be less salient than those from the DINO model, which is consistent with the behaviour shown in the video. The figure below shows the attention heatmaps extracted with a ViT-B16 model pre-trained (supervised) on ImageNet-1k:

[Figure panels: Dinosaur, Dog]

We used this Colab Notebook to conduct this experiment.
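As a rough illustration of how such heatmaps can be produced, the sketch below takes the attention weights of one Transformer block, reads off how strongly the CLS token attends to each patch, and upsamples the result to image resolution. It is written in plain NumPy rather than TensorFlow so that it runs without any pre-trained model. The function name `cls_attention_heatmaps`, the input layout `(heads, tokens, tokens)`, and the nearest-neighbour upsampling are assumptions made for this example, not necessarily what the notebook does.

```python
import numpy as np


def cls_attention_heatmaps(attention, patch_size, num_prefix_tokens=1):
    """Turn CLS-to-patch attention into one heatmap per head.

    attention: array of shape (heads, tokens, tokens) from a single block,
        where the first `num_prefix_tokens` tokens are CLS-like tokens
        (assumed layout for this sketch).
    patch_size: side length of a patch in pixels.
    """
    # Row 0 is the CLS query; keep only its attention to the patch tokens.
    cls_to_patches = attention[:, 0, num_prefix_tokens:]  # (heads, patches)

    num_patches = cls_to_patches.shape[-1]
    grid = int(round(np.sqrt(num_patches)))
    if grid * grid != num_patches:
        raise ValueError("This sketch assumes a square grid of patches.")

    # Lay the patches back out on the image grid (row-major order).
    maps = cls_to_patches.reshape(-1, grid, grid)

    # Nearest-neighbour upsampling: every patch covers patch_size x patch_size pixels.
    return maps.repeat(patch_size, axis=1).repeat(patch_size, axis=2)
```

For a 224x224 input with 16x16 patches, each head would yield a 224x224 map that can be overlaid on the image with any plotting library.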

Hugging Face Spaces

You can now probe into the ViTs with your own input images.

Two Spaces are available: Attention Heat Maps and Attention Rollout.

Visualizing mean attention distances
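For readers unfamiliar with the quantity, here is a minimal sketch of mean attention distance as described in [1, 4]: for each head, the pixel distance between a query patch and every key patch is weighted by the attention weight, summed over keys, and averaged over queries. The sketch uses plain NumPy instead of TensorFlow, and the function names, the `(heads, tokens, tokens)` input layout, and the choice not to renormalize after dropping the CLS token are assumptions for illustration.

```python
import numpy as np


def patch_distance_matrix(grid_size, patch_size):
    """Pairwise pixel distances between patch positions on a square grid."""
    rows, cols = np.divmod(np.arange(grid_size**2), grid_size)
    coords = np.stack([rows, cols], axis=-1).astype(float) * patch_size
    diff = coords[:, None, :] - coords[None, :, :]
    return np.linalg.norm(diff, axis=-1)  # (patches, patches)


def mean_attention_distance(attention, patch_size, num_prefix_tokens=1):
    """Mean attention distance per head for one Transformer block.

    attention: array of shape (heads, tokens, tokens); the first
        `num_prefix_tokens` tokens (e.g. CLS) are dropped (assumed layout).
    """
    patch_attention = attention[:, num_prefix_tokens:, num_prefix_tokens:]
    num_patches = patch_attention.shape[-1]
    grid_size = int(round(np.sqrt(num_patches)))
    if grid_size**2 != num_patches:
        raise ValueError("This sketch assumes a square grid of patches.")

    distances = patch_distance_matrix(grid_size, patch_size)
    # Weight each query-key distance by its attention weight,
    # sum over keys, then average over queries.
    weighted = patch_attention * distances
    return weighted.sum(axis=-1).mean(axis=-1)  # (heads,)
```

A head that attends only to the query's own patch gets a distance of zero, while a head that spreads attention uniformly gets a large one, which is the local-versus-global contrast these plots are meant to expose.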

Methods

We don't propose any novel methods for probing the representations of neural networks. Instead, we take existing methods and implement them in TensorFlow.

  • Mean attention distance [1, 4]
  • Attention Rollout [5]
  • Visualization of the learned projection filters [1]
  • Visualization of the learned positional embeddings
  • Attention maps from individual attention heads [3]
  • Generation of attention heatmaps from videos [3]
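To make one of the methods listed above concrete, the following is a minimal sketch of Attention Rollout as described in [5]: attention is averaged over heads in each block, the identity is added to account for residual connections, rows are renormalized, and the per-block matrices are multiplied from the first block to the last. It uses plain NumPy rather than TensorFlow, and the function names and the `(heads, tokens, tokens)` input layout are assumptions for this example rather than the repository's API.

```python
import numpy as np


def attention_rollout(attentions):
    """Roll out attention across blocks.

    attentions: list of arrays, one per block in first-to-last order,
        each of shape (heads, tokens, tokens) with rows summing to 1
        (assumed layout for this sketch).
    """
    num_tokens = attentions[0].shape[-1]
    identity = np.eye(num_tokens)
    rollout = identity.copy()
    for block_attention in attentions:
        fused = block_attention.mean(axis=0)   # average over heads
        fused = fused + identity               # account for the residual connection
        fused = fused / fused.sum(axis=-1, keepdims=True)  # rows sum to 1 again
        rollout = fused @ rollout              # compose with earlier blocks
    return rollout


def cls_rollout_map(rollout, grid_size):
    """Attention flowing from the CLS token (index 0) to each patch, as a 2D grid."""
    return rollout[0, 1:].reshape(grid_size, grid_size)
```

Because every per-block matrix is row-normalized, each row of the rolled-out matrix still sums to 1, which is a handy sanity check when wiring this up to a real model.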

Another interesting repository that also visualizes ViTs in PyTorch: https://github.com/jacobgil/vit-explain.

Notes

We first implemented the above-mentioned architectures in TensorFlow and then populated them with the pre-trained parameters from the official codebases. To validate the ported models, we evaluated them on the ImageNet-1k validation set and ensured that the top-1 accuracies matched the reported ones.
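The validation step amounts to a top-1 accuracy check, which can be sketched as follows. This is an illustration in plain NumPy, not the evaluation script used for the repository; `top1_accuracy`, `matches_reported`, and the tolerance value are hypothetical names and settings chosen for the example.

```python
import numpy as np


def top1_accuracy(logits, labels):
    """Fraction of examples whose highest-scoring class equals the label.

    logits: array of shape (examples, classes).
    labels: integer array of shape (examples,).
    """
    predictions = np.argmax(logits, axis=-1)
    return float(np.mean(predictions == np.asarray(labels)))


def matches_reported(measured, reported, tolerance=1e-3):
    """True if a measured accuracy is within `tolerance` of a reported one.

    The default tolerance is an arbitrary choice for this sketch.
    """
    return abs(measured - reported) <= tolerance
```

In practice the logits would come from running the ported model over the full validation set in batches and accumulating the number of correct predictions.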

We value the spirit of open source. If you spot any bugs in the code or see scope for improvement, don't hesitate to open an issue or contribute a PR. We'd very much appreciate it.

Navigating through the codebase

Our ViT implementations are in vit. We provide utility notebooks in the notebooks directory which contains the following:

DeiT-related code lives in a separate repository: https://github.com/sayakpaul/deit-tf.

Models

Here are the links to the models where the pre-trained parameters were populated:

Training and visualizing with small datasets

Coming soon!

References

[1] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale: https://arxiv.org/abs/2010.11929

[2] DeiT: https://arxiv.org/abs/2012.12877

[3] DINO: https://arxiv.org/abs/2104.14294

[4] Do Vision Transformers See Like Convolutional Neural Networks?: https://arxiv.org/abs/2108.08810

[5] Quantifying Attention Flow in Transformers: https://arxiv.org/abs/2005.00928

Acknowledgements