From a7ac404e8388c955dca686de352e1aef7820bc3a Mon Sep 17 00:00:00 2001 From: Enrico Fini Date: Tue, 19 Jul 2022 14:21:35 +0200 Subject: [PATCH] fix num steps scheduler when interval is not step (#280) --- solo/methods/base.py | 20 ++++++++---- solo/methods/linear.py | 66 +++++++++++++++++++++++++++------------- tests/args/test_setup.py | 1 - zoo/imagenet100.sh | 6 ++-- 4 files changed, 62 insertions(+), 31 deletions(-) diff --git a/solo/methods/base.py b/solo/methods/base.py index 5ff5842f..dda54bc9 100644 --- a/solo/methods/base.py +++ b/solo/methods/base.py @@ -422,11 +422,21 @@ def configure_optimizers(self) -> Tuple[List, List]: return optimizer if self.scheduler == "warmup_cosine": + max_warmup_steps = ( + self.warmup_epochs * self.num_training_steps + if self.scheduler_interval == "step" + else self.warmup_epochs + ) + max_scheduler_steps = ( + self.max_epochs * self.num_training_steps + if self.scheduler_interval == "step" + else self.max_epochs + ) scheduler = { "scheduler": LinearWarmupCosineAnnealingLR( optimizer, - warmup_epochs=self.warmup_epochs * self.num_training_steps, - max_epochs=self.max_epochs * self.num_training_steps, + warmup_epochs=max_warmup_steps, + max_epochs=max_scheduler_steps, warmup_start_lr=self.warmup_start_lr if self.warmup_epochs > 0 else self.lr, eta_min=self.min_lr, ), @@ -859,11 +869,9 @@ def on_train_batch_end(self, outputs: Dict[str, Any], batch: Sequence[Any], batc # log tau momentum self.log("tau", self.momentum_updater.cur_tau) # update tau - cur_step = self.trainer.global_step - if self.trainer.accumulate_grad_batches: - cur_step = cur_step * self.trainer.accumulate_grad_batches self.momentum_updater.update_tau( - cur_step=cur_step, max_steps=self.max_epochs * self.num_training_steps + cur_step=self.trainer.global_step, + max_steps=self.max_epochs * self.num_training_steps, ) self.last_step = self.trainer.global_step diff --git a/solo/methods/linear.py b/solo/methods/linear.py index dcc7a3b0..182de89b 100644 --- a/solo/methods/linear.py +++ b/solo/methods/linear.py @@ -17,6 +17,7 @@ # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER # DEALINGS IN THE SOFTWARE. +import warnings from argparse import ArgumentParser from typing import Any, Dict, List, Optional, Sequence, Tuple @@ -61,6 +62,7 @@ def __init__( min_lr: float, warmup_start_lr: float, warmup_epochs: float, + scheduler_interval: str = "step", lr_decay_steps: Optional[Sequence[int]] = None, no_channel_last: bool = False, **kwargs, @@ -79,6 +81,7 @@ def __init__( min_lr (float): minimum learning rate for warmup scheduler. warmup_start_lr (float): initial learning rate for warmup scheduler. warmup_epochs (float): number of warmup epochs. + scheduler_interval (str): interval to update the lr scheduler. Defaults to 'step'. lr_decay_steps (Optional[Sequence[int]], optional): list of epochs where the learning rate will be decreased. Defaults to None. no_channel_last (bool). 
Disables channel last conversion operation which @@ -106,6 +109,8 @@ def __init__( self.min_lr = min_lr self.warmup_start_lr = warmup_start_lr self.warmup_epochs = warmup_epochs + assert scheduler_interval in ["step", "epoch"] + self.scheduler_interval = scheduler_interval self.lr_decay_steps = lr_decay_steps self.no_channel_last = no_channel_last @@ -117,6 +122,12 @@ def __init__( for param in self.backbone.parameters(): param.requires_grad = False + if scheduler_interval == "step": + warnings.warn( + f"Using scheduler_interval={scheduler_interval} might generate " + "issues when resuming a checkpoint." + ) + # can provide up to ~20% speed up if not no_channel_last: self = self.to(memory_format=torch.channels_last) @@ -166,6 +177,9 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: parser.add_argument("--min_lr", default=0.0, type=float) parser.add_argument("--warmup_start_lr", default=0.003, type=float) parser.add_argument("--warmup_epochs", default=10, type=int) + parser.add_argument( + "--scheduler_interval", choices=["step", "epoch"], default="step", type=str + ) # disables channel last optimization parser.add_argument("--no_channel_last", action="store_true") @@ -211,24 +225,6 @@ def num_training_steps(self) -> int: return self._num_training_steps - def forward(self, X: torch.tensor) -> Dict[str, Any]: - """Performs forward pass of the frozen backbone and the linear layer for evaluation. - - Args: - X (torch.tensor): a batch of images in the tensor format. - - Returns: - Dict[str, Any]: a dict containing features and logits. - """ - - if not self.no_channel_last: - X = X.to(memory_format=torch.channels_last) - - with torch.no_grad(): - feats = self.backbone(X) - logits = self.classifier(feats) - return {"logits": logits, "feats": feats} - def configure_optimizers(self) -> Tuple[List, List]: """Configures the optimizer for the linear layer. @@ -256,15 +252,25 @@ def configure_optimizers(self) -> Tuple[List, List]: return optimizer if self.scheduler == "warmup_cosine": + max_warmup_steps = ( + self.warmup_epochs * self.num_training_steps + if self.scheduler_interval == "step" + else self.warmup_epochs + ) + max_scheduler_steps = ( + self.max_epochs * self.num_training_steps + if self.scheduler_interval == "step" + else self.max_epochs + ) scheduler = { "scheduler": LinearWarmupCosineAnnealingLR( optimizer, - warmup_epochs=self.warmup_epochs * self.num_training_steps, - max_epochs=self.max_epochs * self.num_training_steps, + warmup_epochs=max_warmup_steps, + max_epochs=max_scheduler_steps, warmup_start_lr=self.warmup_start_lr if self.warmup_epochs > 0 else self.lr, eta_min=self.min_lr, ), - "interval": "step", + "interval": self.scheduler_interval, "frequency": 1, } elif self.scheduler == "reduce": @@ -280,6 +286,24 @@ def configure_optimizers(self) -> Tuple[List, List]: return [optimizer], [scheduler] + def forward(self, X: torch.tensor) -> Dict[str, Any]: + """Performs forward pass of the frozen backbone and the linear layer for evaluation. + + Args: + X (torch.tensor): a batch of images in the tensor format. + + Returns: + Dict[str, Any]: a dict containing features and logits. 
+ """ + + if not self.no_channel_last: + X = X.to(memory_format=torch.channels_last) + + with torch.no_grad(): + feats = self.backbone(X) + logits = self.classifier(feats) + return {"logits": logits, "feats": feats} + def shared_step( self, batch: Tuple, batch_idx: int ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]: diff --git a/tests/args/test_setup.py b/tests/args/test_setup.py index 50db862c..1b80a3f0 100644 --- a/tests/args/test_setup.py +++ b/tests/args/test_setup.py @@ -21,7 +21,6 @@ import os import subprocess import textwrap -from pathlib import Path from solo.args.utils import additional_setup_linear, additional_setup_pretrain from tests.dali.utils import DummyDataset diff --git a/zoo/imagenet100.sh b/zoo/imagenet100.sh index a9e54060..4291a156 100644 --- a/zoo/imagenet100.sh +++ b/zoo/imagenet100.sh @@ -7,7 +7,7 @@ cd imagenet100 mkdir barlow_twins cd barlow_twins gdown https://drive.google.com/uc?id=1C2qQSqp8cXvfrwHVG9MuGTPT2TOTsGla # checkpoint -gdown https://drive.google.com/uc?id=1TY10aa97P4Fl7EgSjTy_u_QME9tkcU4r +gdown https://drive.google.com/uc?id=1TY10aa97P4Fl7EgSjTy_u_QME9tkcU4r # args cd .. # BYOL @@ -48,8 +48,8 @@ cd .. # MoCo V3 mkdir mocov3 cd mocov3 -gdown https://drive.google.com/file/d/1cUaAdx-6NXCkeSMo-mQtpPnYk7zA4Gg4/view?usp=sharing # checkpoint -gdown https://drive.google.com/file/d/1mb6ZRKF1CdGP0rdJI2yjyStZ-FCFjsi4/view?usp=sharing # args +gdown https://drive.google.com/uc?id=1cUaAdx-6NXCkeSMo-mQtpPnYk7zA4Gg4 # checkpoint +gdown https://drive.google.com/uc?id=1mb6ZRKF1CdGP0rdJI2yjyStZ-FCFjsi4 # args cd .. # NNCLR