This repository contains the official implementation and model weights of Data-Efficient Multi-Scale Fusion Vision Transformer (DEMS-ViT).
The vision transformer (ViT) shows strong performance in image classification when trained on massive data, but struggles on small-scale datasets. This paper addresses that data inefficiency by introducing multi-scale tokens, which provide an image prior at multiple scales and enable the learning of scale-invariant features. Our model generates tokens of varying scales from an image using different patch sizes, where each token at a larger scale is linked to a set of tokens at smaller scales based on spatial correspondence. Through a regional cross-scale interaction module, tokens of different scales are fused regionally to enhance the learning of local structures. Additionally, we apply a data augmentation schedule to refine training. Extensive image classification experiments demonstrate that our approach surpasses DeiT and other multi-scale transformer methods on small-scale datasets. The two core components, sketched below, are:
- Multi-Scale Tokenization
- Regional Cross-Scale Interaction
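The following is a minimal PyTorch sketch of the multi-scale tokenization idea, not the official implementation: the patch sizes, embedding dimension, and grouping helper are illustrative assumptions, chosen so that each coarse token covers a 2×2 block of fine tokens.

```python
# Minimal sketch (not the official code): tokenize an image at two patch
# sizes and group each coarse token with the fine tokens it spatially covers.
import torch
import torch.nn as nn

class MultiScaleTokenizer(nn.Module):
    def __init__(self, in_chans=3, embed_dim=192, patch_sizes=(8, 4)):
        super().__init__()
        # One non-overlapping patch embedding per scale (stride = patch size).
        self.projs = nn.ModuleList(
            nn.Conv2d(in_chans, embed_dim, kernel_size=p, stride=p)
            for p in patch_sizes
        )

    def forward(self, x):
        # Returns one token sequence per scale, each shaped (B, N_scale, C).
        return [proj(x).flatten(2).transpose(1, 2) for proj in self.projs]

def link_scales(coarse, fine, ratio=2):
    """Attach to each coarse token the ratio x ratio fine tokens that share
    its spatial location, producing regions of shape (B, Nc, 1 + ratio^2, C)
    ready for regional cross-scale interaction (e.g., attention per region)."""
    B, Nc, C = coarse.shape
    Hc = Wc = int(Nc ** 0.5)
    fine = fine.view(B, Hc, ratio, Wc, ratio, C)
    fine = fine.permute(0, 1, 3, 2, 4, 5).reshape(B, Nc, ratio * ratio, C)
    return torch.cat([coarse.unsqueeze(2), fine], dim=2)

# Example: a 32x32 image gives 4x4 coarse tokens (patch 8) and 8x8 fine
# tokens (patch 4); each region then holds 1 coarse + 4 fine tokens.
tokens_coarse, tokens_fine = MultiScaleTokenizer()(torch.randn(2, 3, 32, 32))
regions = link_scales(tokens_coarse, tokens_fine)  # (2, 16, 5, 192)
```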
To install the requirements:

```bash
conda create -n dems python=3.8
conda activate dems
pip install -r requirements.txt
```
The root paths of the data are set to `/path/to/dataset`; please adjust them to your own locations. The CIFAR10, CIFAR100, FashionMNIST, and EMNIST datasets are provided by torchvision.
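For instance, the torchvision datasets can be instantiated directly against the configured root (the ToTensor-only transform below is a placeholder, not the paper's training recipe):

```python
# Example: loading CIFAR100 through torchvision under the configured root.
# The transform is a placeholder, not the training recipe from the paper.
from torchvision import datasets, transforms

transform = transforms.ToTensor()
train_set = datasets.CIFAR100("/path/to/dataset", train=True,
                              download=True, transform=transform)
val_set = datasets.CIFAR100("/path/to/dataset", train=False,
                            download=True, transform=transform)
```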
Download and extract the Caltech101 train and val images from https://www.vision.caltech.edu/datasets/. The directory structure is the standard layout for torchvision's `datasets.ImageFolder`: the training and validation data are expected to be in the `train/` and `val/` folders, respectively.
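An ImageFolder-compatible layout looks like this (the class folders shown are illustrative):

```
/path/to/caltech101/
├── train/
│   ├── accordion/
│   │   ├── image_0001.jpg
│   │   └── ...
│   └── airplanes/
│       └── ...
└── val/
    ├── accordion/
    │   └── ...
    └── airplanes/
        └── ...
```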
Set the hyperparameters and GPU IDs in `./config/pretrain/dems_small_pretrain.py`.
Run the following command to train DEMS-ViT-S on CIFAR100 for 800 epochs, from random initialization, on a single node with multiple GPUs:

```bash
python main_pretrain.py --model dems_small --batch_size 256 --epochs 800 \
    --dataset CIFAR100 --data_path /path/to/CIFAR100
```
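If `main_pretrain.py` follows the standard torch.distributed entry-point pattern (an assumption; check the script for the launcher it expects), a multi-GPU launch on one node could also go through torchrun:

```bash
# Assumes the script supports torch.distributed; verify before use.
torchrun --standalone --nproc_per_node=4 main_pretrain.py \
    --model dems_small --batch_size 256 --epochs 800 \
    --dataset CIFAR100 --data_path /path/to/CIFAR100
```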
Set the hyperparameters and GPU IDs in `./config/pretrain/dems_small_finetune.py`.
Run the following command to fine-tune DEMS-ViT-S on CIFAR100 for 100 epochs:

```bash
python main_finetune.py --model dems_small --batch_size 256 --epochs 100 \
    --dataset CIFAR100 --data_path /path/to/CIFAR100 \
    --pretrained_weight /path/pretrained
```
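As a rough illustration of what consuming `--pretrained_weight` involves, a minimal loading helper might look like the sketch below; the nested `"model"` key and non-strict loading are assumptions about the checkpoint format, not verified against this repository.

```python
# Sketch: load a pre-trained checkpoint into a model before fine-tuning.
# The "model" key and strict=False handling are assumptions about the format.
import torch
import torch.nn as nn

def load_pretrained(model: nn.Module, ckpt_path: str) -> nn.Module:
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = checkpoint.get("model", checkpoint)  # unwrap if nested
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {missing}\nunexpected keys: {unexpected}")
    return model
```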
We provide models trained on CIFAR, EMNIST, FashionMNIST, and Caltech101 here. In particular, the Caltech101 models are trained with an input size of 256×256 and a patch size of 16.
Name | #FLOPs (G) | #Params | Dataset | Acc@1 | URL |
---|---|---|---|---|---|
DEMS-ViT-Ti | 1.6 | 5.6M | CIFAR10 | 96.03 | model |
DEMS-ViT-Ti | 1.6 | 5.6M | CIFAR100 | 80.60 | model |
DEMS-ViT-Ti | 1.6 | 5.6M | FashionMNIST | 95.59 | model |
DEMS-ViT-Ti | 1.6 | 5.6M | EMNIST | 99.56 | model |
DEMS-ViT-Ti | 1.6 | 5.6M | Caltech101 | 86.56 | model |
DEMS-ViT-S | 5.8 | 22.3M | CIFAR10 | 96.20 | model |
DEMS-ViT-S | 5.8 | 22.3M | CIFAR100 | 83.30 | model |
DEMS-ViT-S | 5.8 | 22.3M | FashionMNIST | 95.99 | model |
DEMS-ViT-S | 5.8 | 22.3M | EMNIST | 99.58 | model |
DEMS-ViT-S | 5.8 | 22.3M | Caltech101 | 86.88 | model |
We provide fine-tuned models on CIFAR, which can be found here.
Name | Dataset | Acc@1 | URL |
---|---|---|---|
DEMS-ViT-Ti | CIFAR10 | 96.74 | model |
DEMS-ViT-Ti | CIFAR100 | 83.50 | model |
DEMS-ViT-S | CIFAR10 | 97.76 | model |
DEMS-ViT-S | CIFAR100 | 85.16 | model |
This project is under the CC-BY-NC 4.0 license. See LICENSE for details.