Add Mocov3 #268

Merged: 14 commits, Jun 28, 2022
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@ The library is self-contained, but it is possible to use the models outside of s
---

## News
* **[Jun 26 2022]**: :fire: Added [MoCo V3](https://arxiv.org/abs/2104.02057).
* **[Jun 10 2022]**: :bomb: Improved LARS and fixed some issues to support [Horovod](https://horovod.readthedocs.io/en/stable/pytorch.html).
* **[Jun 09 2022]**: :lollipop: Added support for [WideResnet](https://arxiv.org/abs/1605.07146), multicrop for SwAV and equalization data augmentation.
* **[May 02 2022]**: :diamond_shape_with_a_dot_inside: Wrapped Dali with a DataModule, added auto resume for linear eval and Wandb run resume.
@@ -47,6 +48,7 @@ The library is self-contained, but it is possible to use the models outside of s
* [DeepCluster V2](https://arxiv.org/abs/2006.09882)
* [DINO](https://arxiv.org/abs/2104.14294)
* [MoCo V2+](https://arxiv.org/abs/2003.04297)
* [MoCo V3](https://arxiv.org/abs/2104.02057)
* [NNBYOL](https://arxiv.org/abs/2104.14548)
* [NNCLR](https://arxiv.org/abs/2104.14548)
* [NNSiam](https://arxiv.org/abs/2104.14548)
@@ -177,11 +179,12 @@ All pretrained models available can be downloaded directly via the tables below o

| Method | Backbone | Epochs | Dali | Acc@1 | Acc@5 | Checkpoint |
|--------------|:--------:|:------:|:----:|:--------------:|:--------------:|:----------:|
| Barlow Twins | ResNet18 | 1000 | :x: | 92.10 | 99.73 | [:link:](https://drive.google.com/drive/folders/1L5RAM3lCSViD2zEqLtC-GQKVw6mxtxJ_?usp=sharing) |
| BYOL | ResNet18 | 1000 | :x: | 92.58 | 99.79 | [:link:](https://drive.google.com/drive/folders/1KxeYAEE7Ev9kdFFhXWkPZhG-ya3_UwGP?usp=sharing) |
|DeepCluster V2| ResNet18 | 1000 | :x: | 88.85 | 99.58 | [:link:](https://drive.google.com/drive/folders/1tkEbiDQ38vZaQUsT6_vEpxbDxSUAGwF-?usp=sharing) |
| DINO | ResNet18 | 1000 | :x: | 89.52 | 99.71 | [:link:](https://drive.google.com/drive/folders/1vyqZKUyP8sQyEyf2cqonxlGMbQC-D1Gi?usp=sharing) |
| MoCo V2+ | ResNet18 | 1000 | :x: | 92.94 | 99.79 | [:link:](https://drive.google.com/drive/folders/1ruNFEB3F-Otxv2Y0p62wrjA4v5Fr2cKC?usp=sharing) |
| MoCo V3 | ResNet18 | 1000 | :x: | 93.10 | 99.80 | [:link:](https://drive.google.com/drive/folders/1KwZTshNEpmqnYJcmyYPvfIJ_DNwqtAVj?usp=sharing) |
| NNCLR | ResNet18 | 1000 | :x: | 91.88 | 99.78 | [:link:](https://drive.google.com/drive/folders/1xdCzhvRehPmxinphuiZqFlfBwfwWDcLh?usp=sharing) |
| ReSSL | ResNet18 | 1000 | :x: | 90.63 | 99.62 | [:link:](https://drive.google.com/drive/folders/1jrFcztY2eO_fG98xPshqOD15pDIhLXp-?usp=sharing) |
| SimCLR | ResNet18 | 1000 | :x: | 90.74 | 99.75 | [:link:](https://drive.google.com/drive/folders/1mcvWr8P2WNJZ7TVpdLHA_Q91q4VK3y8O?usp=sharing) |
@@ -202,6 +205,7 @@ All pretrained models available can be downloaded directly via the tables below o
|DeepCluster V2| ResNet18 | 1000 | :x: | 63.61 | 88.09 | [:link:](https://drive.google.com/drive/folders/1gAKyMz41mvGh1BBOYdc_xu6JPSkKlWqK?usp=sharing) |
| DINO | ResNet18 | 1000 | :x: | 66.76 | 90.34 | [:link:](https://drive.google.com/drive/folders/1TxeZi2YLprDDtbt_y5m29t4euroWr1Fy?usp=sharing) |
| MoCo V2+ | ResNet18 | 1000 | :x: | 69.89 | 91.65 | [:link:](https://drive.google.com/drive/folders/15oWNM16vO6YVYmk_yOmw2XUrFivRXam4?usp=sharing) |
| MoCo V3 | ResNet18 | 1000 | :x: | 68.83 | 90.57 | [:link:](https://drive.google.com/drive/folders/1Hcf9kMIADKydfxvXLquY9nv7sfNaJ3v6?usp=sharing) |
| NNCLR | ResNet18 | 1000 | :x: | 69.62 | 91.52 | [:link:](https://drive.google.com/drive/folders/1Dz72o0-5hugYPW1kCCQDBb0Xi3kzMLzu?usp=sharing) |
| ReSSL | ResNet18 | 1000 | :x: | 65.92 | 89.73 | [:link:](https://drive.google.com/drive/folders/1aVZs9cHAu6Ccz8ILyWkp6NhTsJGBGfjr?usp=sharing) |
| SimCLR | ResNet18 | 1000 | :x: | 65.78 | 89.04 | [:link:](https://drive.google.com/drive/folders/13pGPcOO9Y3rBoeRVWARgbMFEp8OXxZa0?usp=sharing) |
@@ -222,6 +226,7 @@ All pretrained models available can be downloaded directly via the tables below o
| DINO | ResNet18 | 400 | :heavy_check_mark: | 74.84 | 74.92 | 92.92 | 92.78 | [:link:](https://drive.google.com/drive/folders/1NtVvRj-tQJvrMxRlMtCJSAecQnYZYkqs?usp=sharing) |
| DINO :sleepy: | ViT Tiny | 400 | :x: | 63.04 | TODO | 87.72 | TODO | [:link:](https://drive.google.com/drive/folders/16AfsM-UpKky43kdSMlqj4XRe69pRdJLc?usp=sharing) |
| MoCo V2+ :rocket: | ResNet18 | 400 | :heavy_check_mark: | 78.20 | 79.28 | 95.50 | 95.18 | [:link:](https://drive.google.com/drive/folders/1ItYBtMJ23Yh-Rhrvwjm4w1waFfUGSoKX?usp=sharing) |
| MoCo V3 :rocket: | ResNet18 | 400 | :heavy_check_mark: | 80.36 | 80.36 | 95.18 | 94.96 | [:link:](https://drive.google.com/drive/folders/15J0JiZsQAsrQler8mbbio-desb_nVoD1?usp=sharing) |
| NNCLR :rocket: | ResNet18 | 400 | :heavy_check_mark: | 79.80 | 80.16 | 95.28 | 95.30 | [:link:](https://drive.google.com/drive/folders/1QMkq8w3UsdcZmoNUIUPgfSCAZl_LSNjZ?usp=sharing) |
| ReSSL | ResNet18 | 400 | :heavy_check_mark: | 76.92 | 78.48 | 94.20 | 94.24 | [:link:](https://drive.google.com/drive/folders/1urWIFACLont4GAduis6l0jcEbl080c9U?usp=sharing) |
| SimCLR :rocket: | ResNet18 | 400 | :heavy_check_mark: | 77.64 | TODO | 94.06 | TODO | [:link:](https://drive.google.com/drive/folders/1yxAVKnc8Vf0tDfkixSB5mXe7dsA8Ll37?usp=sharing) |
@@ -245,6 +250,7 @@ All pretrained models available can be downloaded directly via the tables below o
|DeepCluster V2| ResNet50 | 100 | :heavy_check_mark: | | | | | |
| DINO | ResNet50 | 100 | :heavy_check_mark: | | | | | |
| MoCo V2+ | ResNet50 | 100 | :heavy_check_mark: | 62.61 | 66.84 | 85.40 | 87.60 | [:link:](https://drive.google.com/drive/folders/1NiBDmieEpNqkwrgn_H7bMnEDVAYc8Sk7?usp=sharing) |
| MoCo V3 | ResNet50 | 100 | :heavy_check_mark: | | | | | |
| NNCLR | ResNet50 | 100 | :heavy_check_mark: | | | | | |
| ReSSL | ResNet50 | 100 | :heavy_check_mark: | | | | | |
| SimCLR | ResNet50 | 100 | :heavy_check_mark: | | | | | |
6 changes: 4 additions & 2 deletions bash_files/linear/imagenet-100/general_linear.sh
@@ -17,9 +17,11 @@ python3 main_linear.py \
--weight_decay 0 \
--batch_size 128 \
--num_workers 10 \
--name simclr-linear-eval \
--dali \
--name method-linear-eval \
--pretrained_feature_extractor PATH \
--project contrastive_learning \
--project solo-learn \
--entity unitn-mhug \
--wandb \
--save_checkpoint \
--auto_resume
26 changes: 26 additions & 0 deletions bash_files/linear/imagenet-100/mocov3_linear.sh
@@ -0,0 +1,26 @@
python3 main_linear.py \
--dataset imagenet100 \
--backbone resnet18 \
--data_dir /datasets \
--train_dir imagenet-100/train \
--val_dir imagenet-100/val \
--max_epochs 100 \
--devices 0,1 \
--accelerator gpu \
--strategy ddp \
--sync_batchnorm \
--precision 16 \
--optimizer sgd \
--scheduler warmup_cosine \
--warmup_epochs 0 \
--lr 0.3 \
--weight_decay 0 \
--batch_size 128 \
--num_workers 10 \
--dali \
--name mocov3-linear-eval \
--pretrained_feature_extractor PATH \
--project solo-learn \
--entity unitn-mhug \
--wandb \
--save_checkpoint
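The script above runs linear evaluation: the pretrained backbone is frozen and only a linear classifier on top of its features is trained. A minimal sketch of that setup in plain PyTorch (the module shapes here are illustrative stand-ins, not solo-learn's actual classes):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-in for a frozen pretrained trunk; a real run would load a ResNet18
# checkpoint via --pretrained_feature_extractor. Dimensions are illustrative.
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 512))
for p in backbone.parameters():
    p.requires_grad = False  # linear eval: the backbone stays fixed

classifier = nn.Linear(512, 100)  # imagenet-100 has 100 classes
optimizer = torch.optim.SGD(classifier.parameters(), lr=0.3, weight_decay=0)

x = torch.randn(4, 3, 8, 8)
y = torch.randint(0, 100, (4,))
with torch.no_grad():
    feats = backbone(x)  # features are detached from the graph
logits = classifier(feats)
loss = F.cross_entropy(logits, y)
loss.backward()          # gradients reach only the classifier
optimizer.step()
```

With `--weight_decay 0` and no backbone gradients, the probe measures the quality of the frozen representation rather than fine-tuning it.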
38 changes: 38 additions & 0 deletions bash_files/pretrain/cifar/mocov3.sh
@@ -0,0 +1,38 @@
python3 main_pretrain.py \
--dataset $1 \
--backbone resnet18 \
--data_dir ./datasets \
--max_epochs 1000 \
--devices 0 \
--accelerator gpu \
--precision 16 \
--optimizer lars \
--eta_lars 0.02 \
--exclude_bias_n_norm \
--scheduler warmup_cosine \
--lr 0.3 \
--classifier_lr 0.3 \
--weight_decay 1e-6 \
--batch_size 256 \
--num_workers 4 \
--brightness 0.4 \
--contrast 0.4 \
--saturation 0.4 \
--hue 0.1 \
--gaussian_prob 1.0 0.1 \
--solarization_prob 0.0 0.2 \
--min_scale 0.2 \
--crop_size 32 \
--num_crops_per_aug 1 1 \
--name mocov3-$1 \
--project solo-learn \
--entity unitn-mhug \
--wandb \
--save_checkpoint \
--auto_resume \
--method mocov3 \
--proj_hidden_dim 4096 \
--pred_hidden_dim 4096 \
--temperature 0.2 \
--base_tau_momentum 0.99 \
--final_tau_momentum 1.0
42 changes: 42 additions & 0 deletions bash_files/pretrain/imagenet-100/mocov3.sh
@@ -0,0 +1,42 @@
python3 main_pretrain.py \
--dataset imagenet100 \
--backbone resnet18 \
--data_dir /datasets \
--train_dir imagenet-100/train \
--val_dir imagenet-100/val \
--max_epochs 400 \
--devices 0,1 \
--accelerator gpu \
--strategy ddp \
--sync_batchnorm \
--precision 16 \
--optimizer lars \
--eta_lars 0.02 \
--exclude_bias_n_norm \
--scheduler warmup_cosine \
--lr 0.3 \
--classifier_lr 0.3 \
--weight_decay 1e-6 \
--batch_size 128 \
--num_workers 4 \
--dali \
--brightness 0.4 \
--contrast 0.4 \
--saturation 0.2 \
--hue 0.1 \
--gaussian_prob 1.0 0.1 \
--solarization_prob 0.0 0.2 \
--min_scale 0.2 \
--num_crops_per_aug 1 1 \
--name mocov3-400ep-imagenet100 \
--project solo-learn \
--entity unitn-mhug \
--save_checkpoint \
--wandb \
--auto_resume \
--method mocov3 \
--proj_hidden_dim 4096 \
--pred_hidden_dim 4096 \
--temperature 0.2 \
--base_tau_momentum 0.99 \
--final_tau_momentum 1.0
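`--base_tau_momentum 0.99` and `--final_tau_momentum 1.0` control the EMA coefficient of the momentum encoder. In the MoCo V3 / BYOL papers this coefficient is annealed to 1 with a cosine schedule; a sketch of that rule (solo-learn's exact implementation may differ):

```python
import math

def cosine_tau(step: int, max_steps: int,
               base_tau: float = 0.99, final_tau: float = 1.0) -> float:
    # progress decays from 1 to 0, so tau rises from base_tau to final_tau
    progress = (math.cos(math.pi * step / max_steps) + 1) / 2
    return final_tau - (final_tau - base_tau) * progress
```

At step 0 this returns `base_tau`, and at the last step `final_tau`, so the momentum encoder tracks the online encoder closely early on and becomes nearly frozen by the end of training.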
2 changes: 2 additions & 0 deletions docs/source/index.rst
@@ -68,6 +68,7 @@ While the library is self-contained, it is possible to use the models outside of
solo/methods/deepclusterv2
solo/methods/dino
solo/methods/mocov2plus
solo/methods/mocov3
solo/methods/nnbyol
solo/methods/nnclr
solo/methods/nnsiam
@@ -91,6 +92,7 @@ While the library is self-contained, it is possible to use the models outside of
solo/losses/deepclusterv2
solo/losses/dino
solo/losses/mocov2plus
solo/losses/mocov3
solo/losses/nnclr
solo/losses/ressl
solo/losses/simclr
2 changes: 1 addition & 1 deletion docs/source/solo/losses/mocov2plus.rst
@@ -1,5 +1,5 @@
MoCo-V2
-------

.. autofunction:: solo.losses.moco.moco_loss_func
.. autofunction:: solo.losses.mocov2plus.mocov2plus_loss_func
:noindex:
5 changes: 5 additions & 0 deletions docs/source/solo/losses/mocov3.rst
@@ -0,0 +1,5 @@
MoCo-V3
-------

.. autofunction:: solo.losses.mocov3.mocov3_loss_func
:noindex:
31 changes: 31 additions & 0 deletions docs/source/solo/methods/mocov3.rst
@@ -0,0 +1,31 @@
MoCo-V3
=======


.. automethod:: solo.methods.mocov3.MoCoV3.__init__
:noindex:

add_model_specific_args
~~~~~~~~~~~~~~~~~~~~~~~
.. automethod:: solo.methods.mocov3.MoCoV3.add_model_specific_args
:noindex:

learnable_params
~~~~~~~~~~~~~~~~
.. autoattribute:: solo.methods.mocov3.MoCoV3.learnable_params
:noindex:

momentum_pairs
~~~~~~~~~~~~~~
.. autoattribute:: solo.methods.mocov3.MoCoV3.momentum_pairs
:noindex:

forward
~~~~~~~
.. automethod:: solo.methods.mocov3.MoCoV3.forward
:noindex:

training_step
~~~~~~~~~~~~~
.. automethod:: solo.methods.mocov3.MoCoV3.training_step
:noindex:
4 changes: 2 additions & 2 deletions main_linear.py
@@ -58,7 +58,7 @@
def main():
args = parse_args_linear()

assert args.backbone in BaseMethod._SUPPORTED_BACKBONES
assert args.backbone in BaseMethod._BACKBONES
backbone_model = {
"resnet18": resnet18,
"resnet50": resnet50,
@@ -94,7 +94,7 @@ def main():
)
ckpt_path = args.pretrained_feature_extractor

state = torch.load(ckpt_path)["state_dict"]
state = torch.load(ckpt_path, map_location="cpu")["state_dict"]
for k in list(state.keys()):
if "encoder" in k:
state[k.replace("encoder", "backbone")] = state[k]
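The change above loads the checkpoint onto CPU (so linear eval does not depend on the device layout the model was saved with) and then renames `encoder.*` keys to `backbone.*` for older checkpoints. The loop body is truncated in the diff; a self-contained sketch of that renaming with dummy values (the `del` is presumably what the elided line does):

```python
# Dummy state dict standing in for torch.load(ckpt_path, map_location="cpu")["state_dict"]
state = {
    "encoder.conv1.weight": 0.1,
    "encoder.layer1.0.bn1.bias": 0.2,
    "classifier.weight": 0.3,
}

for k in list(state.keys()):       # list() so we can mutate while iterating
    if "encoder" in k:
        state[k.replace("encoder", "backbone")] = state[k]
        del state[k]               # drop the old key (assumed from the elided lines)
```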
6 changes: 4 additions & 2 deletions solo/losses/__init__.py
@@ -21,7 +21,8 @@
from solo.losses.byol import byol_loss_func
from solo.losses.deepclusterv2 import deepclusterv2_loss_func
from solo.losses.dino import DINOLoss
from solo.losses.moco import moco_loss_func
from solo.losses.mocov2plus import mocov2plus_loss_func
from solo.losses.mocov3 import mocov3_loss_func
from solo.losses.nnclr import nnclr_loss_func
from solo.losses.ressl import ressl_loss_func
from solo.losses.simclr import simclr_loss_func
@@ -36,7 +37,8 @@
"byol_loss_func",
"deepclusterv2_loss_func",
"DINOLoss",
"moco_loss_func",
"mocov2plus_loss_func",
"mocov3_loss_func",
"nnclr_loss_func",
"ressl_loss_func",
"simclr_loss_func",
6 changes: 3 additions & 3 deletions solo/losses/moco.py → solo/losses/mocov2plus.py
@@ -21,17 +21,17 @@
import torch.nn.functional as F


def moco_loss_func(
def mocov2plus_loss_func(
query: torch.Tensor, key: torch.Tensor, queue: torch.Tensor, temperature=0.1
) -> torch.Tensor:
"""Computes MoCo's loss given a batch of queries from view 1, a batch of keys from view 2 and a
queue of past elements.

Args:
query (torch.Tensor): NxD Tensor containing the queries from view 1.
key (torch.Tensor): NxD Tensor containing the queries from view 2.
key (torch.Tensor): NxD Tensor containing the keys from view 2.
queue (torch.Tensor): a queue of negative samples for the contrastive loss.
temperature (float, optional): [description]. temperature of the softmax in the contrastive
temperature (float, optional): temperature of the softmax in the contrastive
loss. Defaults to 0.1.

Returns:
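The body of the renamed loss is elided in the diff above; for reference, the textbook MoCo v2 InfoNCE objective that `mocov2plus_loss_func`'s docstring describes can be sketched as follows (this is the standard formulation under stated assumptions, not necessarily solo-learn's exact code):

```python
import torch
import torch.nn.functional as F

def infonce_with_queue(query, key, queue, temperature=0.1):
    """query, key: NxD (assumed already L2-normalized); queue: DxK of past keys."""
    pos = torch.einsum("nc,nc->n", [query, key]).unsqueeze(-1)  # Nx1 positive logits
    neg = torch.einsum("nc,ck->nk", [query, queue])             # NxK negative logits
    logits = torch.cat([pos, neg], dim=1) / temperature
    # the positive is always at index 0 after the concatenation
    labels = torch.zeros(query.size(0), dtype=torch.long, device=query.device)
    return F.cross_entropy(logits, labels)

torch.manual_seed(0)
q = F.normalize(torch.randn(4, 16), dim=1)
k = F.normalize(torch.randn(4, 16), dim=1)
queue = F.normalize(torch.randn(16, 32), dim=0)  # 32 queued negatives
loss = infonce_with_queue(q, k, queue)
```

The queue is what distinguishes this from the MoCo V3 loss added in this PR, which contrasts against the other keys in the (gathered) batch instead.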
70 changes: 70 additions & 0 deletions solo/losses/mocov3.py
@@ -0,0 +1,70 @@
# Copyright 2021 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import torch
import torch.distributed as dist
import torch.nn.functional as F


@torch.no_grad()
def concat_all_gather_no_grad(tensor):
"""
Performs all_gather operation on the provided tensors.
*** Warning ***: torch.distributed.all_gather has no gradient.
"""

if dist.is_available() and dist.is_initialized():
tensors_gather = [
torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
]
torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

output = torch.cat(tensors_gather, dim=0)
return output
return tensor


def mocov3_loss_func(query: torch.Tensor, key: torch.Tensor, temperature=0.2) -> torch.Tensor:
"""Computes MoCo V3's loss given a batch of queries from view 1, a batch of keys from view 2 and a
queue of past elements.

Args:
query (torch.Tensor): NxD Tensor containing the queries from view 1.
key (torch.Tensor): NxD Tensor containing the keys from view 2.
temperature (float, optional): temperature of the softmax in the contrastive
loss. Defaults to 0.2.

Returns:
torch.Tensor: MoCo V3 loss.
"""

n = query.size(0)
device = query.device
rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0

query = F.normalize(query, dim=1)
key = F.normalize(key, dim=1)

# gather all targets without gradients
key = concat_all_gather_no_grad(key)

logits = torch.einsum("nc,mc->nm", [query, key]) / temperature
labels = torch.arange(n, dtype=torch.long, device=device) + n * rank

return F.cross_entropy(logits, labels) * (2 * temperature)
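In a single-process run the all-gather is a no-op and the labels reduce to `arange(n)`; a quick sanity check of the loss above under that assumption, restated inline so it runs without the library:

```python
import torch
import torch.nn.functional as F

def mocov3_loss_single(query, key, temperature=0.2):
    # single-process specialization of mocov3_loss_func: no all_gather, rank 0
    query = F.normalize(query, dim=1)
    key = F.normalize(key, dim=1)
    logits = torch.einsum("nc,mc->nm", [query, key]) / temperature
    labels = torch.arange(query.size(0), dtype=torch.long, device=query.device)
    return F.cross_entropy(logits, labels) * (2 * temperature)

torch.manual_seed(0)
q = torch.randn(8, 128)
k = torch.randn(8, 128)
mismatched = mocov3_loss_single(q, k)       # unrelated views: high loss
aligned = mocov3_loss_single(q, q.clone())  # identical views: near-zero loss
```

The `2 * temperature` scaling matches the official MoCo V3 code, which keeps gradient magnitudes comparable across temperature settings.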
3 changes: 3 additions & 0 deletions solo/methods/__init__.py
@@ -24,6 +24,7 @@
from solo.methods.dino import DINO
from solo.methods.linear import LinearModel
from solo.methods.mocov2plus import MoCoV2Plus
from solo.methods.mocov3 import MoCoV3
from solo.methods.nnbyol import NNBYOL
from solo.methods.nnclr import NNCLR
from solo.methods.nnsiam import NNSiam
@@ -46,6 +47,7 @@
"deepclusterv2": DeepClusterV2,
"dino": DINO,
"mocov2plus": MoCoV2Plus,
"mocov3": MoCoV3,
"nnbyol": NNBYOL,
"nnclr": NNCLR,
"nnsiam": NNSiam,
@@ -66,6 +68,7 @@
"DINO",
"LinearModel",
"MoCoV2Plus",
"MoCoV3",
"NNBYOL",
"NNCLR",
"NNSiam",