Commit e77510c: update to release
IcarusWizard committed Nov 28, 2021
1 parent 2d14ada
Showing 4 changed files with 49 additions and 14 deletions.
9 changes: 6 additions & 3 deletions mae_pretrain.py
@@ -26,12 +26,15 @@

setup_seed(args.seed)

assert args.batch_size % args.max_device_batch_size == 0
steps_per_update = args.batch_size // args.max_device_batch_size
batch_size = args.batch_size
load_batch_size = min(args.max_device_batch_size, batch_size)

assert batch_size % load_batch_size == 0
steps_per_update = batch_size // load_batch_size

train_dataset = torchvision.datasets.CIFAR10('data', train=True, download=True, transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
val_dataset = torchvision.datasets.CIFAR10('data', train=False, download=True, transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
dataloader = torch.utils.data.DataLoader(train_dataset, args.max_device_batch_size, shuffle=True, num_workers=4)
dataloader = torch.utils.data.DataLoader(train_dataset, load_batch_size, shuffle=True, num_workers=4)
writer = SummaryWriter(os.path.join('logs', 'cifar10', 'mae-pretrain'))
device = 'cuda' if torch.cuda.is_available() else 'cpu'

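The new `load_batch_size` / `steps_per_update` pair splits one effective batch into several smaller device batches. The training loop that consumes `steps_per_update` is truncated in this diff; below is a minimal sketch of the usual gradient-accumulation pattern it implies. `model`, `optim`, `dataloader`, and `device` stand for the objects defined in mae_pretrain.py, and the forward signature and loss form are assumptions, not the repo's exact code.

```python
# Minimal gradient-accumulation sketch (not the repo's exact loop).
# `model`, `optim`, `dataloader`, and `device` are assumed to be the MAE model,
# its optimizer, and the CIFAR-10 loader defined in mae_pretrain.py.
optim.zero_grad()
for step, (img, _) in enumerate(dataloader):
    img = img.to(device)
    predicted_img, mask = model(img)                      # assumed forward signature
    loss = torch.mean((predicted_img - img) ** 2 * mask)  # masked reconstruction loss (assumed form)
    (loss / steps_per_update).backward()                  # scale so accumulated grads match the full batch
    if (step + 1) % steps_per_update == 0:                # one optimizer step per effective batch
        optim.step()
        optim.zero_grad()
```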
Binary file added pic/mae-cifar10-reconstruction.png
39 changes: 34 additions & 5 deletions readme.md
@@ -1,8 +1,37 @@
Implementation of [*Kaiming He et al., Masked Autoencoders Are Scalable Vision Learners*](https://arxiv.org/abs/2111.06377).
### Implementation of [*Kaiming He et al., Masked Autoencoders Are Scalable Vision Learners*](https://arxiv.org/abs/2111.06377).

Due to limited resources, we only test the model on Cifar10. We mainly want to reproduce the result that **pre-training a ViT with MAE achieves a better result than training it directly with supervised labels**. This should be evidence that **self-supervised learning is more data-efficient than supervised learning**.

|Model|Test Acc|
|-----|--------|
|ViT-T||
|ViT-T-MAE||
We mainly follow the implementation details in the paper. However, due to the differences between Cifar10 and ImageNet, we make some modifications:
- we use ViT-Tiny instead of ViT-Base.
- since Cifar10 has only 50k training images, we increase the pretraining epochs from 400 to 2000 and the warmup epochs from 40 to 200 (see the schedule sketch below). We noticed that the loss is still decreasing after 2000 epochs.
- we decrease the batch size for training the classifier from 1024 to 128 to mitigate overfitting.
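For reference, here is a minimal sketch of the linear-warmup plus cosine-decay schedule that the warmup/total epoch settings above suggest; the exact formula used in the training scripts may differ, and the `Linear` module below is only a placeholder for the real model.

```python
import math
import torch

warmup_epoch, total_epoch = 200, 2000
optimizer = torch.optim.AdamW(torch.nn.Linear(8, 8).parameters(), lr=1.5e-4)  # placeholder model

def lr_lambda(epoch):
    if epoch < warmup_epoch:
        return (epoch + 1) / warmup_epoch                         # linear warmup
    progress = (epoch - warmup_epoch) / (total_epoch - warmup_epoch)
    return 0.5 * (1 + math.cos(math.pi * progress))               # cosine decay to 0

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
# scheduler.step() is then called once per epoch after the optimizer updates.
```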

### Installation
`pip install -r requirements.txt`

### Run
```bash
# pretrain with MAE
python mae_pretrain.py

# train classifier from scratch
python train_classifier.py

# train classifier from pretrained model
python train_classifier.py --pretrained_model_path vit-t-mae.pt --output_model_path vit-t-classifier-from_pretrained.pt
```

View the training logs with `tensorboard --logdir logs`.

### Result
|Model|Validation Acc (%)|
|-----|--------------|
|ViT-T w/o pretrain|74.13|
|ViT-T w/ pretrain|**89.77**|

Weights and tensorboard logs are available in the GitHub release. You can also view the tensorboard at [tensorboard.dev](https://tensorboard.dev/experiment/zngzZ89bTpyM1B2zVrD7Yw/#scalars).

Visualization of the first 16 images of the Cifar10 validation set:

![avatar](pic/mae-cifar10-reconstruction.png)
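A minimal sketch of how such a reconstruction grid could be produced from the released weights. It assumes the checkpoint stores the whole module and that the MAE's forward pass returns a reconstructed image together with a per-pixel binary mask (1 = masked patch); the actual interface in this repo may differ.

```python
import torch
import torchvision
from torchvision.transforms import Compose, ToTensor, Normalize

val_dataset = torchvision.datasets.CIFAR10('data', train=False, download=True,
                                            transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
val_imgs = torch.stack([val_dataset[i][0] for i in range(16)])   # first 16 validation images

model = torch.load('vit-t-mae.pt', map_location='cpu')           # assumes the whole module was saved
model.eval()
with torch.no_grad():
    predicted_img, mask = model(val_imgs)                        # assumed forward signature
    composed = predicted_img * mask + val_imgs * (1 - mask)      # keep visible patches from the input

grid = torchvision.utils.make_grid((composed + 1) / 2, nrow=4)   # undo Normalize(0.5, 0.5)
torchvision.utils.save_image(grid, 'reconstruction.png')
```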
15 changes: 9 additions & 6 deletions train_classifier.py
@@ -13,26 +13,29 @@
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--batch_size', type=int, default=1024)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--max_device_batch_size', type=int, default=256)
parser.add_argument('--base_learning_rate', type=float, default=1e-3)
parser.add_argument('--weight_decay', type=float, default=0.05)
parser.add_argument('--total_epoch', type=int, default=100)
parser.add_argument('--warmup_epoch', type=int, default=5)
parser.add_argument('--pretrained_model_path', type=str, default=None)
parser.add_argument('--output_model_path', type=str, default='vit-t-mae-scratch.pt')
parser.add_argument('--output_model_path', type=str, default='vit-t-classifier-from_scratch.pt')

args = parser.parse_args()

setup_seed(args.seed)

assert args.batch_size % args.max_device_batch_size == 0
steps_per_update = args.batch_size // args.max_device_batch_size
batch_size = args.batch_size
load_batch_size = min(args.max_device_batch_size, batch_size)

assert batch_size % load_batch_size == 0
steps_per_update = batch_size // load_batch_size

train_dataset = torchvision.datasets.CIFAR10('data', train=True, download=True, transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
val_dataset = torchvision.datasets.CIFAR10('data', train=False, download=True, transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
train_dataloader = torch.utils.data.DataLoader(train_dataset, args.max_device_batch_size, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, args.max_device_batch_size, shuffle=False, num_workers=4)
train_dataloader = torch.utils.data.DataLoader(train_dataset, load_batch_size, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, load_batch_size, shuffle=False, num_workers=4)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

if args.pretrained_model_path is not None:
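The optimizer setup that consumes `base_learning_rate` is truncated above. A common convention, and the one used in the MAE paper, is to scale the base rate linearly with the effective batch size; the sketch below illustrates that rule with the script's default settings and a placeholder model, and is not necessarily the exact code in train_classifier.py.

```python
import torch

base_learning_rate = 1e-3
batch_size = 128
lr = base_learning_rate * batch_size / 256        # linear scaling rule -> 5e-4 with the defaults

model = torch.nn.Linear(192, 10)                  # placeholder for the ViT-T classifier
optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
```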
