Added support for H5 data, improved scripts and data handling #275

Merged: 41 commits, Jul 14, 2022
Commits
1814c47
moco vit
vturrisi Jul 8, 2022
7acb1f8
wip
vturrisi Jul 8, 2022
e038f91
wip
vturrisi Jul 8, 2022
c9efc6c
script
vturrisi Jul 8, 2022
6cc4b1a
fixes
vturrisi Jul 8, 2022
80972d1
fixes
vturrisi Jul 8, 2022
ece07e7
fixes
vturrisi Jul 8, 2022
654b33f
fixes
vturrisi Jul 8, 2022
a887f21
merged with vit branch
vturrisi Jul 8, 2022
1053b73
Merge branch 'main' into h5
vturrisi Jul 8, 2022
580d212
Update dataset.py
vturrisi Jul 8, 2022
87f0d1a
conversion func for h5
vturrisi Jul 11, 2022
cc5488e
fixes
vturrisi Jul 11, 2022
45357a5
fix test
vturrisi Jul 11, 2022
39aa09b
Update utils.py
vturrisi Jul 11, 2022
99eee04
wip
vturrisi Jul 12, 2022
91df454
Update mocov3_vit_h5.sh
vturrisi Jul 12, 2022
14e39ec
Update mocov3_vit_h5.sh
vturrisi Jul 12, 2022
5f8e5e9
Update README.md
DonkeyShot21 Jul 12, 2022
6884efd
Update mocov3_vit_h5.sh
vturrisi Jul 12, 2022
3d5a04e
Merge branch 'fix-readme' of https://github.com/vturrisi/solo-learn i…
vturrisi Jul 13, 2022
0346487
wip
vturrisi Jul 13, 2022
32a5ef8
Merge branch 'h5' of https://github.com/vturrisi/solo-learn into h5
vturrisi Jul 13, 2022
3ec5565
readme
vturrisi Jul 13, 2022
8a1ad24
fix tests?
vturrisi Jul 13, 2022
ff317aa
minor stuff
vturrisi Jul 13, 2022
cdfc1ca
other minor
vturrisi Jul 13, 2022
7053ec3
done?
vturrisi Jul 13, 2022
f511cc0
fixup: format solo with Black
Jul 13, 2022
c197ea6
fix tests
vturrisi Jul 13, 2022
d5d2c09
Merge branch 'h5' of https://github.com/vturrisi/solo-learn into h5
vturrisi Jul 13, 2022
039181a
more fixes
vturrisi Jul 13, 2022
c72f189
wip
vturrisi Jul 13, 2022
d410faf
small tweaks
vturrisi Jul 13, 2022
9375474
fixed linear
vturrisi Jul 13, 2022
a2220ac
typo
vturrisi Jul 13, 2022
75cc370
fixed h5
vturrisi Jul 13, 2022
75a34b0
fix knn and umap scripts
vturrisi Jul 13, 2022
5c31587
small tweaks
vturrisi Jul 13, 2022
6de7de3
Update tests.yml
vturrisi Jul 13, 2022
bbc2b46
Update dali_tests.yml
vturrisi Jul 13, 2022
5 changes: 3 additions & 2 deletions .github/workflows/dali_tests.yml
@@ -34,7 +34,8 @@ jobs:
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[umap] codecov mypy pytest-cov black
pip install .[dali,umap,h5] --extra-index-url https://developer.download.nvidia.com/compute/redist codecov
pip install mypy pytest-cov black

- name: Cache datasets
uses: actions/cache@v2
@@ -61,4 +62,4 @@ jobs:
file: coverage.xml
flags: dali
name: DALI-coverage
fail_ci_if_error: false
fail_ci_if_error: false
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -34,7 +34,7 @@ jobs:
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[umap] codecov mypy pytest-cov black
pip install -e .[umap,h5] codecov mypy pytest-cov black

- name: Cache datasets
uses: actions/cache@v2
28 changes: 15 additions & 13 deletions README.md
@@ -15,8 +15,9 @@ The library is self-contained, but it is possible to use the models outside of s
---

## News
* **[Jul 13 2022]**: :sparkling_heart: Added support for [H5](https://docs.h5py.org/en/stable/index.html) data, improved scripts and data handling.
* **[Jun 26 2022]**: :fire: Added [MoCo V3](https://arxiv.org/abs/2104.02057).
* **[Jun 10 2022]**: :bomb: Improved LARS and fixed some issues to support [Horovod](https://horovod.readthedocs.io/en/stable/pytorch.html).
* **[Jun 10 2022]**: :bomb: Improved LARS.
* **[Jun 09 2022]**: :lollipop: Added support for [WideResnet](https://arxiv.org/abs/1605.07146), multicrop for SwAV and equalization data augmentation.
* **[May 02 2022]**: :diamond_shape_with_a_dot_inside: Wrapped Dali with a DataModule, added auto resume for linear eval and Wandb run resume.
* **[Apr 12 2022]**: :rainbow: Improved design of models and added support to train with a fraction of data.
@@ -65,7 +66,7 @@ The library is self-contained, but it is possible to use the models outside of s

## Extra flavor

### Multiple backbones
### Backbones
* [ResNet](https://arxiv.org/abs/1512.03385)
* [WideResNet](https://arxiv.org/abs/1605.07146)
* [ViT](https://arxiv.org/abs/2010.11929)
@@ -77,22 +78,23 @@ The library is self-contained, but it is possible to use the models outside of s
* Increased data processing speed by up to 100% using [Nvidia Dali](https://github.com/NVIDIA/DALI).
* Flexible augmentations.

### Evaluation and logging
### Evaluation
* Online linear evaluation via stop-gradient for easier debugging and prototyping (optionally available for the momentum backbone as well).
* Standard offline linear evaluation.
* Online and offline K-NN evaluation.
* Normal offline linear evaluation.
* All the perks of PyTorch Lightning (mixed precision, gradient accumulation, clipping, automatic logging and much more).
* Easy-to-extend modular code structure.
* Custom model logging with a simpler file organization.
* Automatic feature space visualization with UMAP.
* Offline UMAP.
* Common metrics.

### Training tricks
* All the perks of PyTorch Lightning (mixed precision, gradient accumulation, clipping, and much more).
* Channel last conversion
* Multi-cropping dataloading following [SwAV](https://arxiv.org/abs/2006.09882):
* **Note**: currently, only SimCLR, BYOL and SwAV support this.
* Exclude batchnorm and biases from LARS.
* No LR scheduler for the projection head in SimSiam.
* Exclude batchnorm and biases from weight decay and LARS.
* No LR scheduler for the projection head (as in SimSiam).

### Logging
* Metric logging on the cloud with [WandB](https://wandb.ai/site)
* Custom model checkpointing with a simple file organization.

---
## Requirements
@@ -122,10 +124,10 @@ First clone the repo.

Then, to install solo-learn with [Dali](https://github.com/NVIDIA/DALI) and/or UMAP support, use:
```
pip3 install .[dali,umap] --extra-index-url https://developer.download.nvidia.com/compute/redist
pip3 install .[dali,umap,h5] --extra-index-url https://developer.download.nvidia.com/compute/redist
```

If no Dali/UMAP support is needed, the repository can be installed as:
If no Dali/UMAP/H5 support is needed, the repository can be installed as:
```
pip3 install .
```
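
Because `h5` is now an optional extra next to `dali` and `umap`, it can help to check which optional dependencies are importable before launching a run. The following is a minimal sketch, not part of this PR: `is_importable` and `OPTIONAL_MODULES` are hypothetical names, and the module names `nvidia.dali`, `umap`, and `h5py` are assumptions about what the extras install.

```python
import importlib.util


def is_importable(module_name: str) -> bool:
    """Return True if an import spec for `module_name` can be found."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # A missing parent package (e.g. `nvidia` for `nvidia.dali`) raises
        # ModuleNotFoundError, which is a subclass of ImportError.
        return False


# Assumed module names for the optional extras; adjust if they differ.
OPTIONAL_MODULES = {"dali": "nvidia.dali", "umap": "umap", "h5": "h5py"}

if __name__ == "__main__":
    for extra, module in OPTIONAL_MODULES.items():
        status = "available" if is_importable(module) else "missing"
        print(f"[{extra}] {module}: {status}")
```

Running the file prints one line per extra, which makes it easy to see whether `pip3 install .` was used without the extras.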
2 changes: 1 addition & 1 deletion main_knn.py
@@ -1,4 +1,4 @@
# Copyright 2021 solo-learn development team.
# Copyright 2022 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
25 changes: 14 additions & 11 deletions main_linear.py
@@ -1,4 +1,4 @@
# Copyright 2021 solo-learn development team.
# Copyright 2022 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
@@ -57,7 +57,7 @@ def main():
if "swin" in args.backbone and cifar:
kwargs["window_size"] = 4

backbone = backbone_model(**kwargs)
backbone = backbone_model(method=None, **kwargs)
if args.backbone.startswith("resnet"):
# remove fc layer
backbone.fc = nn.Identity()
@@ -90,25 +90,28 @@ def main():
model = LinearModel(backbone, **args.__dict__)
make_contiguous(model)

if args.data_format == "dali":
val_data_format = "image_folder"
else:
val_data_format = args.data_format
train_loader, val_loader = prepare_data(
args.dataset,
data_dir=args.data_dir,
train_dir=args.train_dir,
val_dir=args.val_dir,
train_data_path=args.train_data_path,
val_data_path=args.val_data_path,
data_format=val_data_format,
batch_size=args.batch_size,
num_workers=args.num_workers,
data_fraction=args.data_fraction,
)
if args.dali:

if args.data_format == "dali":
assert (
_dali_avaliable
), "Dali is not currently available, please install it first with [dali]."

dali_datamodule = ClassificationDALIDataModule(
dataset=args.dataset,
data_dir=args.data_dir,
train_dir=args.train_dir,
val_dir=args.val_dir,
train_data_path=args.train_data_path,
val_data_path=args.val_data_path,
num_workers=args.num_workers,
batch_size=args.batch_size,
data_fraction=args.data_fraction,
@@ -192,7 +195,7 @@ def prefetch_batches(self) -> int:
except:
pass

if args.dali:
if args.data_format == "dali":
trainer.fit(model, ckpt_path=ckpt_path, datamodule=dali_datamodule)
else:
trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)
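
The `val_data_format` fallback introduced in this file can be read as a small pure function. A minimal sketch under that reading; `resolve_val_data_format` is a hypothetical helper name that does not appear in the PR, and `"h5"` is used only as an example of a non-DALI value:

```python
def resolve_val_data_format(data_format: str) -> str:
    """Pick the data format passed to the regular (non-DALI) loaders."""
    # Mirrors the diff: "dali" maps to "image_folder";
    # any other value is passed through unchanged.
    return "image_folder" if data_format == "dali" else data_format


if __name__ == "__main__":
    for fmt in ("dali", "image_folder", "h5"):
        print(fmt, "->", resolve_val_data_format(fmt))
```

Factoring the branch out this way would let `main_linear.py` and `main_pretrain.py` share it, since both files repeat the same `if`/`else`.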
28 changes: 16 additions & 12 deletions main_pretrain.py
@@ -1,4 +1,4 @@
# Copyright 2021 solo-learn development team.
# Copyright 2022 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
@@ -69,28 +69,32 @@ def main():
make_contiguous(model)

# validation dataloader for when it is available
if args.dataset == "custom" and (args.no_labels or args.val_dir is None):
if args.dataset == "custom" and (args.no_labels or args.val_data_path is None):
val_loader = None
elif args.dataset in ["imagenet100", "imagenet"] and args.val_dir is None:
elif args.dataset in ["imagenet100", "imagenet"] and (args.val_data_path is None):
val_loader = None
else:
if args.data_format == "dali":
val_data_format = "image_folder"
else:
val_data_format = args.data_format

_, val_loader = prepare_data_classification(
args.dataset,
data_dir=args.data_dir,
train_dir=args.train_dir,
val_dir=args.val_dir,
train_data_path=args.train_data_path,
val_data_path=args.val_data_path,
data_format=val_data_format,
batch_size=args.batch_size,
num_workers=args.num_workers,
)

# pretrain dataloader
if args.dali:
if args.data_format == "dali":
assert _dali_avaliable, "Dali is not available, please install it first with [dali]."

dali_datamodule = PretrainDALIDataModule(
dataset=args.dataset,
data_dir=args.data_dir,
train_dir=args.train_dir,
train_data_path=args.train_data_path,
unique_augs=args.unique_augs,
transform_kwargs=args.transform_kwargs,
num_crops_per_aug=args.num_crops_per_aug,
@@ -120,8 +124,8 @@ def main():
train_dataset = prepare_datasets(
args.dataset,
transform,
data_dir=args.data_dir,
train_dir=args.train_dir,
train_data_path=args.train_data_path,
data_format=args.data_format,
no_labels=args.no_labels,
data_fraction=args.data_fraction,
)
@@ -214,7 +218,7 @@ def prefetch_batches(self) -> int:
except:
pass

if args.dali:
if args.data_format == "dali":
trainer.fit(model, ckpt_path=ckpt_path, datamodule=dali_datamodule)
else:
trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)
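
The rule that decides whether `main_pretrain.py` builds a validation loader now depends on `val_data_path` instead of `val_dir`. A minimal sketch of that condition as a standalone predicate; `should_build_val_loader` is a hypothetical name, and `"cifar10"` is only an example of a dataset that falls through to the `else` branch:

```python
from typing import Optional


def should_build_val_loader(
    dataset: str, no_labels: bool, val_data_path: Optional[str]
) -> bool:
    """Return False in the cases where the diff sets `val_loader = None`."""
    # Custom datasets need labels and an explicit validation path.
    if dataset == "custom" and (no_labels or val_data_path is None):
        return False
    # ImageNet variants need an explicit validation path.
    if dataset in ("imagenet100", "imagenet") and val_data_path is None:
        return False
    # Every other combination gets a validation loader.
    return True


if __name__ == "__main__":
    print(should_build_val_loader("custom", True, "/data/val"))
    print(should_build_val_loader("imagenet100", False, "/data/val"))
```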
2 changes: 1 addition & 1 deletion main_umap.py
@@ -1,4 +1,4 @@
# Copyright 2021 solo-learn development team.
# Copyright 2022 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
2 changes: 1 addition & 1 deletion requirements.txt
@@ -8,4 +8,4 @@ tqdm
wandb
scipy
timm
scikit-learn
scikit-learn
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
python3 main_knn.py \
--dataset imagenet100 \
--data_dir /datasets \
--train_dir imagenet-100/train \
--val_dir imagenet-100/val \
--train_data_path /datasets/imagenet-100/train \
--val_data_path /datasets/imagenet-100/val \
--batch_size 16 \
--num_workers 10 \
--pretrained_checkpoint_dir PATH \
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
python3 main_linear.py \
--dataset imagenet100 \
--backbone resnet18 \
--data_dir /datasets \
--train_dir imagenet-100/train \
--val_dir imagenet-100/val \
--train_data_path /datasets/imagenet-100/train \
--val_data_path /datasets/imagenet-100/val \
--max_epochs 100 \
--devices 0 \
--accelerator gpu \
@@ -15,11 +14,11 @@ python3 main_linear.py \
--weight_decay 0 \
--batch_size 256 \
--num_workers 4 \
--dali \
--data_format dali \
--name barlow-imagenet100-linear-eval \
--pretrained_feature_extractor PATH \
--project solo-learn \
--entity unitn-mhug \
--wandb \
--save_checkpoint \
--auto_resume
--auto_resume
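
The script changes follow one pattern: `--data_dir` plus `--train_dir`/`--val_dir` become full `--train_data_path`/`--val_data_path` values, and `--dali` becomes `--data_format dali`. For out-of-tree scripts, that rewrite could be automated roughly as below. This is a sketch only: `migrate_cli_args` is a hypothetical helper that is not part of solo-learn, and it assumes the old directories were joined with `/`.

```python
import posixpath
from typing import Dict, List


def migrate_cli_args(argv: List[str]) -> List[str]:
    """Rewrite old-style data flags into the flags used after this PR."""
    old_values: Dict[str, str] = {}
    migrated: List[str] = []
    i = 0
    while i < len(argv):
        token = argv[i]
        if token in ("--data_dir", "--train_dir", "--val_dir"):
            if i + 1 >= len(argv):
                raise ValueError(f"{token} expects a value")
            old_values[token] = argv[i + 1]  # remember, emit later
            i += 2
        elif token == "--dali":
            migrated += ["--data_format", "dali"]
            i += 1
        else:
            migrated.append(token)  # unrelated flags pass through
            i += 1

    # The new paths are appended at the end. A bare --data_dir without
    # --train_dir/--val_dir has no new equivalent and is dropped.
    data_dir = old_values.get("--data_dir", "")
    if "--train_dir" in old_values:
        train = posixpath.join(data_dir, old_values["--train_dir"])
        migrated += ["--train_data_path", train]
    if "--val_dir" in old_values:
        val = posixpath.join(data_dir, old_values["--val_dir"])
        migrated += ["--val_data_path", val]
    return migrated


if __name__ == "__main__":
    old = ["--dataset", "imagenet100", "--data_dir", "/datasets",
           "--train_dir", "imagenet-100/train",
           "--val_dir", "imagenet-100/val", "--dali"]
    print(" ".join(migrate_cli_args(old)))
```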
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
python3 main_linear.py \
--dataset imagenet100 \
--backbone resnet18 \
--data_dir /datasets \
--train_dir imagenet-100/train \
--val_dir imagenet-100/val \
--train_data_path /datasets/imagenet-100/train \
--val_data_path /datasets/imagenet-100/val \
--max_epochs 100 \
--devices 0 \
--accelerator gpu \
@@ -15,7 +14,7 @@ python3 main_linear.py \
--weight_decay 0 \
--batch_size 256 \
--num_workers 4 \
--dali \
--data_format dali \
--name byol-imagenet100-linear-eval \
--pretrained_feature_extractor PATH \
--project solo-learn \
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
python3 main_linear.py \
--dataset imagenet100 \
--backbone resnet18 \
--data_dir /data/datasets \
--train_dir imagenet-100/train \
--val_dir imagenet-100/val \
--train_data_path /datasets/imagenet-100/train \
--val_data_path /datasets/imagenet-100/val \
--max_epochs 100 \
--devices 0 \
--accelerator gpu \
@@ -15,7 +14,7 @@ python3 main_linear.py \
--weight_decay 0 \
--batch_size 256 \
--num_workers 5 \
--dali \
--data_format dali \
--name deepclusterv2-imagenet100-linear-eval \
--pretrained_feature_extractor PATH --project solo-learn \
--entity unitn-mhug \
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
python3 main_linear.py \
--dataset imagenet100 \
--backbone resnet18 \
--data_dir /datasets \
--train_dir imagenet-100/train \
--val_dir imagenet-100/val \
--train_data_path /datasets/imagenet-100/train \
--val_data_path /datasets/imagenet-100/val \
--max_epochs 100 \
--devices 0 \
--accelerator gpu \
@@ -15,7 +14,7 @@ python3 main_linear.py \
--weight_decay 0 \
--batch_size 256 \
--num_workers 4 \
--dali \
--data_format dali \
--name dino-imagenet100-linear-eval \
--pretrained_feature_extractor PATH \
--project solo-learn \
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
python3 main_linear.py \
--dataset imagenet100 \
--backbone resnet18 \
--data_dir /datasets \
--train_dir imagenet-100/train \
--val_dir imagenet-100/val \
--train_data_path /datasets/imagenet-100/train \
--val_data_path /datasets/imagenet-100/val \
--max_epochs 100 \
--devices 0,1 \
--accelerator gpu \
@@ -17,7 +16,7 @@ python3 main_linear.py \
--weight_decay 0 \
--batch_size 128 \
--num_workers 10 \
--dali \
--data_format dali \
--name method-linear-eval \
--pretrained_feature_extractor PATH \
--project solo-learn \
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
python3 main_linear.py \
--dataset imagenet100 \
--backbone resnet18 \
--data_dir /datasets \
--train_dir imagenet-100/train \
--val_dir imagenet-100/val \
--train_data_path /datasets/imagenet-100/train \
--val_data_path /datasets/imagenet-100/val \
--max_epochs 100 \
--devices 0 \
--accelerator gpu \
@@ -15,7 +14,7 @@ python3 main_linear.py \
--weight_decay 0 \
--batch_size 256 \
--num_workers 10 \
--dali \
--data_format dali \
--name mocov2plus-imagenet100-linear-eval \
--pretrained_feature_extractor PATH \
--project solo-learn \