fix num steps scheduler when interval is not step (#280)
DonkeyShot21 authored Jul 19, 2022
1 parent 2deab45 commit a7ac404
Showing 4 changed files with 62 additions and 31 deletions.
20 changes: 14 additions & 6 deletions solo/methods/base.py
@@ -422,11 +422,21 @@ def configure_optimizers(self) -> Tuple[List, List]:
             return optimizer
 
         if self.scheduler == "warmup_cosine":
+            max_warmup_steps = (
+                self.warmup_epochs * self.num_training_steps
+                if self.scheduler_interval == "step"
+                else self.warmup_epochs
+            )
+            max_scheduler_steps = (
+                self.max_epochs * self.num_training_steps
+                if self.scheduler_interval == "step"
+                else self.max_epochs
+            )
             scheduler = {
                 "scheduler": LinearWarmupCosineAnnealingLR(
                     optimizer,
-                    warmup_epochs=self.warmup_epochs * self.num_training_steps,
-                    max_epochs=self.max_epochs * self.num_training_steps,
+                    warmup_epochs=max_warmup_steps,
+                    max_epochs=max_scheduler_steps,
                     warmup_start_lr=self.warmup_start_lr if self.warmup_epochs > 0 else self.lr,
                     eta_min=self.min_lr,
                 ),
@@ -859,11 +869,9 @@ def on_train_batch_end(self, outputs: Dict[str, Any], batch: Sequence[Any], batch_idx: int):
             # log tau momentum
             self.log("tau", self.momentum_updater.cur_tau)
             # update tau
-            cur_step = self.trainer.global_step
-            if self.trainer.accumulate_grad_batches:
-                cur_step = cur_step * self.trainer.accumulate_grad_batches
             self.momentum_updater.update_tau(
-                cur_step=cur_step, max_steps=self.max_epochs * self.num_training_steps
+                cur_step=self.trainer.global_step,
+                max_steps=self.max_epochs * self.num_training_steps,
             )
             self.last_step = self.trainer.global_step
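
For reference, a self-contained sketch (not part of the commit) of the interval-dependent durations computed above: with scheduler_interval == "step" the warmup and total lengths are expressed in optimizer steps, otherwise they stay in epochs. The helper name and the example numbers below are hypothetical.

def scheduler_lengths(warmup_epochs, max_epochs, num_training_steps, scheduler_interval):
    """Return (warmup, total) durations in the unit the scheduler is stepped in."""
    # Hypothetical stand-in for the logic added in configure_optimizers above.
    if scheduler_interval == "step":
        return warmup_epochs * num_training_steps, max_epochs * num_training_steps
    return warmup_epochs, max_epochs

# e.g. 10 warmup epochs, 100 total epochs, 500 optimizer steps per epoch
print(scheduler_lengths(10, 100, 500, "step"))   # -> (5000, 50000)
print(scheduler_lengths(10, 100, 500, "epoch"))  # -> (10, 100)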
66 changes: 45 additions & 21 deletions solo/methods/linear.py
@@ -17,6 +17,7 @@
 # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 # DEALINGS IN THE SOFTWARE.
 
+import warnings
 from argparse import ArgumentParser
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
@@ -61,6 +62,7 @@ def __init__(
         min_lr: float,
         warmup_start_lr: float,
         warmup_epochs: float,
+        scheduler_interval: str = "step",
         lr_decay_steps: Optional[Sequence[int]] = None,
         no_channel_last: bool = False,
         **kwargs,
@@ -79,6 +81,7 @@ def __init__(
             min_lr (float): minimum learning rate for warmup scheduler.
             warmup_start_lr (float): initial learning rate for warmup scheduler.
             warmup_epochs (float): number of warmup epochs.
+            scheduler_interval (str): interval to update the lr scheduler. Defaults to 'step'.
             lr_decay_steps (Optional[Sequence[int]], optional): list of epochs where the learning
                 rate will be decreased. Defaults to None.
             no_channel_last (bool). Disables channel last conversion operation which
@@ -106,6 +109,8 @@ def __init__(
         self.min_lr = min_lr
         self.warmup_start_lr = warmup_start_lr
         self.warmup_epochs = warmup_epochs
+        assert scheduler_interval in ["step", "epoch"]
+        self.scheduler_interval = scheduler_interval
         self.lr_decay_steps = lr_decay_steps
         self.no_channel_last = no_channel_last
 
@@ -117,6 +122,12 @@ def __init__(
         for param in self.backbone.parameters():
             param.requires_grad = False
 
+        if scheduler_interval == "step":
+            warnings.warn(
+                f"Using scheduler_interval={scheduler_interval} might generate "
+                "issues when resuming a checkpoint."
+            )
+
         # can provide up to ~20% speed up
         if not no_channel_last:
             self = self.to(memory_format=torch.channels_last)
@@ -166,6 +177,9 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
         parser.add_argument("--min_lr", default=0.0, type=float)
         parser.add_argument("--warmup_start_lr", default=0.003, type=float)
         parser.add_argument("--warmup_epochs", default=10, type=int)
+        parser.add_argument(
+            "--scheduler_interval", choices=["step", "epoch"], default="step", type=str
+        )
 
         # disables channel last optimization
         parser.add_argument("--no_channel_last", action="store_true")
@@ -211,24 +225,6 @@ def num_training_steps(self) -> int:
 
         return self._num_training_steps
 
-    def forward(self, X: torch.tensor) -> Dict[str, Any]:
-        """Performs forward pass of the frozen backbone and the linear layer for evaluation.
-
-        Args:
-            X (torch.tensor): a batch of images in the tensor format.
-
-        Returns:
-            Dict[str, Any]: a dict containing features and logits.
-        """
-
-        if not self.no_channel_last:
-            X = X.to(memory_format=torch.channels_last)
-
-        with torch.no_grad():
-            feats = self.backbone(X)
-        logits = self.classifier(feats)
-        return {"logits": logits, "feats": feats}
-
     def configure_optimizers(self) -> Tuple[List, List]:
         """Configures the optimizer for the linear layer.
@@ -256,15 +252,25 @@ def configure_optimizers(self) -> Tuple[List, List]:
             return optimizer
 
         if self.scheduler == "warmup_cosine":
+            max_warmup_steps = (
+                self.warmup_epochs * self.num_training_steps
+                if self.scheduler_interval == "step"
+                else self.warmup_epochs
+            )
+            max_scheduler_steps = (
+                self.max_epochs * self.num_training_steps
+                if self.scheduler_interval == "step"
+                else self.max_epochs
+            )
             scheduler = {
                 "scheduler": LinearWarmupCosineAnnealingLR(
                     optimizer,
-                    warmup_epochs=self.warmup_epochs * self.num_training_steps,
-                    max_epochs=self.max_epochs * self.num_training_steps,
+                    warmup_epochs=max_warmup_steps,
+                    max_epochs=max_scheduler_steps,
                     warmup_start_lr=self.warmup_start_lr if self.warmup_epochs > 0 else self.lr,
                     eta_min=self.min_lr,
                 ),
-                "interval": "step",
+                "interval": self.scheduler_interval,
                 "frequency": 1,
             }
         elif self.scheduler == "reduce":
@@ -280,6 +286,24 @@ def configure_optimizers(self) -> Tuple[List, List]:
 
         return [optimizer], [scheduler]
 
+    def forward(self, X: torch.tensor) -> Dict[str, Any]:
+        """Performs forward pass of the frozen backbone and the linear layer for evaluation.
+
+        Args:
+            X (torch.tensor): a batch of images in the tensor format.
+
+        Returns:
+            Dict[str, Any]: a dict containing features and logits.
+        """
+
+        if not self.no_channel_last:
+            X = X.to(memory_format=torch.channels_last)
+
+        with torch.no_grad():
+            feats = self.backbone(X)
+        logits = self.classifier(feats)
+        return {"logits": logits, "feats": feats}
+
     def shared_step(
         self, batch: Tuple, batch_idx: int
     ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]:
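
To see how the "interval" key interacts with the durations, here is a hedged, self-contained sketch of a Lightning-style scheduler dict. It substitutes torch's CosineAnnealingLR for LinearWarmupCosineAnnealingLR so it needs only PyTorch, and the numeric values are invented for illustration; it is not the repository's code.

import torch

model = torch.nn.Linear(2048, 100)                  # stand-in for the linear classifier
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

scheduler_interval = "epoch"                        # or "step"
num_training_steps = 500                            # optimizer steps per epoch (made up)
max_epochs = 100

# Same rule as the diff: express the schedule length in whatever unit the scheduler is stepped in.
t_max = max_epochs * num_training_steps if scheduler_interval == "step" else max_epochs
scheduler = {
    "scheduler": torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max),
    "interval": scheduler_interval,                 # Lightning steps it per step or per epoch
    "frequency": 1,
}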
1 change: 0 additions & 1 deletion tests/args/test_setup.py
@@ -21,7 +21,6 @@
 import os
 import subprocess
 import textwrap
-from pathlib import Path
 
 from solo.args.utils import additional_setup_linear, additional_setup_pretrain
 from tests.dali.utils import DummyDataset
6 changes: 3 additions & 3 deletions zoo/imagenet100.sh
@@ -7,7 +7,7 @@ cd imagenet100
 mkdir barlow_twins
 cd barlow_twins
 gdown https://drive.google.com/uc?id=1C2qQSqp8cXvfrwHVG9MuGTPT2TOTsGla # checkpoint
-gdown https://drive.google.com/uc?id=1TY10aa97P4Fl7EgSjTy_u_QME9tkcU4r
+gdown https://drive.google.com/uc?id=1TY10aa97P4Fl7EgSjTy_u_QME9tkcU4r # args
 cd ..
 
 # BYOL
@@ -48,8 +48,8 @@ cd ..
 # MoCo V3
 mkdir mocov3
 cd mocov3
-gdown https://drive.google.com/file/d/1cUaAdx-6NXCkeSMo-mQtpPnYk7zA4Gg4/view?usp=sharing # checkpoint
-gdown https://drive.google.com/file/d/1mb6ZRKF1CdGP0rdJI2yjyStZ-FCFjsi4/view?usp=sharing # args
+gdown https://drive.google.com/uc?id=1cUaAdx-6NXCkeSMo-mQtpPnYk7zA4Gg4 # checkpoint
+gdown https://drive.google.com/uc?id=1mb6ZRKF1CdGP0rdJI2yjyStZ-FCFjsi4 # args
 cd ..
 
 # NNCLR
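
The zoo fix swaps Drive "file/d/<id>/view" share links for the "uc?id=<id>" form that gdown can fetch directly. Below is a small hypothetical Python helper (the function name and regex are assumptions, not part of the commit) that performs the same rewrite.

import re

def to_gdown_url(share_url: str) -> str:
    """Rewrite a Drive 'file/d/<id>/view' share link into the direct 'uc?id=<id>' form."""
    match = re.search(r"/file/d/([\w-]+)", share_url)
    if not match:
        return share_url  # already direct (or not a share link); leave untouched
    return f"https://drive.google.com/uc?id={match.group(1)}"

print(to_gdown_url("https://drive.google.com/file/d/1cUaAdx-6NXCkeSMo-mQtpPnYk7zA4Gg4/view?usp=sharing"))
# -> https://drive.google.com/uc?id=1cUaAdx-6NXCkeSMo-mQtpPnYk7zA4Gg4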
