Skip to content

Commit

Permalink
Add MosaicML's ChannelLast optization (#235)
Browse files Browse the repository at this point in the history
  • Loading branch information
vturrisi authored Apr 1, 2022
1 parent f25dc1f commit 467a656
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ While the library is self-contained, it is possible to use the models outside of
---

## News
* **[Apr 01 2022]**: :mag: Added [MosaicML's](https://github.com/mosaicml/composer) ChannelLast operation which considerably decreases training times.
* **[Feb 04 2022]**: :partying_face: Paper got accepted to JMLR.
* **[Jan 31 2022]**: :eye: Added ConvNeXt support with timm.
* **[Dec 20 2021]**: :thermometer: Added ImageNet results, scripts and checkpoints for MoCo V2+.
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ tqdm
wandb
scipy
timm
mosaicml
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def parse_requirements(path):
"wandb",
"scipy",
"timm",
"mosaicml",
],
extras_require=EXTRA_REQUIREMENTS,
dependency_links=["https://developer.download.nvidia.com/compute/redist"],
Expand Down
14 changes: 14 additions & 0 deletions solo/methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from functools import partial
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union

import composer.functional as cf
import pytorch_lightning as pl
import torch
import torch.nn as nn
Expand Down Expand Up @@ -115,6 +116,7 @@ def __init__(
lr_decay_steps: Sequence = None,
knn_eval: bool = False,
knn_k: int = 20,
no_mosaicml_channel_last=False,
**kwargs,
):
"""Base model that implements all basic operations for all self-supervised methods.
Expand Down Expand Up @@ -157,6 +159,9 @@ def __init__(
step. Defaults to None.
knn_eval (bool): enables online knn evaluation while training.
knn_k (int): the number of neighbors to use for knn.
no_mosaicml_channel_last (bool). Disables MosaicML ChannelLast operation which
speeds up training considerably (https://github.com/mosaicml/composer).
Defaults to False.
.. note::
When using distributed data parallel, the batch size and the number of workers are
Expand Down Expand Up @@ -258,6 +263,11 @@ def __init__(
"issues when resuming a checkpoint."
)

# https://docs.mosaicml.com/en/v0.5.0/method_cards/channels_last.html
# can provide up to ~20% speed up
if not no_mosaicml_channel_last:
cf.apply_channels_last(self)

@staticmethod
def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
"""Adds shared basic arguments that are shared for all methods.
Expand Down Expand Up @@ -331,6 +341,10 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
parser.add_argument("--knn_eval", action="store_true")
parser.add_argument("--knn_k", default=20, type=int)

# mosaicml optimization
# disables mosaicml channel last optization
parser.add_argument("--no_mosaicml_channel_last", action="store_true")

return parent_parser

def set_loaders(self, train_loader: DataLoader = None, val_loader: DataLoader = None) -> None:
Expand Down

0 comments on commit 467a656

Please sign in to comment.