fix num steps scheduler when interval is not step #280

Merged · 7 commits · Jul 19, 2022

20 changes: 14 additions & 6 deletions solo/methods/base.py
@@ -422,11 +422,21 @@ def configure_optimizers(self) -> Tuple[List, List]:
return optimizer

if self.scheduler == "warmup_cosine":
max_warmup_steps = (
self.warmup_epochs * self.num_training_steps
if self.scheduler_interval == "step"
else self.warmup_epochs
)
max_scheduler_steps = (
self.max_epochs * self.num_training_steps
if self.scheduler_interval == "step"
else self.max_epochs
)
scheduler = {
"scheduler": LinearWarmupCosineAnnealingLR(
optimizer,
warmup_epochs=self.warmup_epochs * self.num_training_steps,
max_epochs=self.max_epochs * self.num_training_steps,
warmup_epochs=max_warmup_steps,
max_epochs=max_scheduler_steps,
warmup_start_lr=self.warmup_start_lr if self.warmup_epochs > 0 else self.lr,
eta_min=self.min_lr,
),
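For context on the new branches above: LinearWarmupCosineAnnealingLR must receive the warmup and total horizon in whatever unit it will be stepped in, so the epoch counts are multiplied by the number of steps per epoch only when the scheduler is stepped per optimizer step. A minimal sketch of that conversion (the helper name and the concrete numbers are illustrative, not from the repo):

```python
# Illustrative only: how the scheduler horizon changes with the interval.
def scheduler_horizon(warmup_epochs, max_epochs, steps_per_epoch, interval):
    """Return (warmup, max) in the unit the scheduler will be stepped in."""
    if interval == "step":
        # stepped after every optimizer step -> express horizons in steps
        return warmup_epochs * steps_per_epoch, max_epochs * steps_per_epoch
    # stepped once per epoch -> keep epoch units
    return warmup_epochs, max_epochs

print(scheduler_horizon(10, 100, 390, "step"))   # (3900, 39000)
print(scheduler_horizon(10, 100, 390, "epoch"))  # (10, 100)
```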
@@ -859,11 +869,9 @@ def on_train_batch_end(self, outputs: Dict[str, Any], batch: Sequence[Any], batc
# log tau momentum
self.log("tau", self.momentum_updater.cur_tau)
# update tau
cur_step = self.trainer.global_step
if self.trainer.accumulate_grad_batches:
cur_step = cur_step * self.trainer.accumulate_grad_batches
self.momentum_updater.update_tau(
cur_step=cur_step, max_steps=self.max_epochs * self.num_training_steps
cur_step=self.trainer.global_step,
max_steps=self.max_epochs * self.num_training_steps,
)
self.last_step = self.trainer.global_step

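On the on_train_batch_end hunk above: update_tau advances the EMA momentum coefficient from the ratio cur_step / max_steps, and the change passes trainer.global_step directly instead of rescaling it by accumulate_grad_batches. As a hedged sketch of the kind of cosine schedule such a momentum updater commonly implements (the formula and the tau values below are assumptions for illustration, not read from this repo):

```python
import math

def cosine_tau(cur_step: int, max_steps: int,
               base_tau: float = 0.996, final_tau: float = 1.0) -> float:
    """BYOL-style cosine increase of the EMA coefficient tau (illustrative)."""
    progress = cur_step / max_steps
    return final_tau - (final_tau - base_tau) * (math.cos(math.pi * progress) + 1) / 2

print(cosine_tau(0, 39000))      # ~0.996 at the start of training
print(cosine_tau(39000, 39000))  # 1.0 by the final step
```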
66 changes: 45 additions & 21 deletions solo/methods/linear.py
@@ -17,6 +17,7 @@
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import warnings
from argparse import ArgumentParser
from typing import Any, Dict, List, Optional, Sequence, Tuple

@@ -61,6 +62,7 @@ def __init__(
min_lr: float,
warmup_start_lr: float,
warmup_epochs: float,
scheduler_interval: str = "step",
lr_decay_steps: Optional[Sequence[int]] = None,
no_channel_last: bool = False,
**kwargs,
@@ -79,6 +81,7 @@ def __init__(
min_lr (float): minimum learning rate for warmup scheduler.
warmup_start_lr (float): initial learning rate for warmup scheduler.
warmup_epochs (float): number of warmup epochs.
scheduler_interval (str): interval to update the lr scheduler. Defaults to 'step'.
lr_decay_steps (Optional[Sequence[int]], optional): list of epochs where the learning
rate will be decreased. Defaults to None.
no_channel_last (bool). Disables channel last conversion operation which
@@ -106,6 +109,8 @@ def __init__(
self.min_lr = min_lr
self.warmup_start_lr = warmup_start_lr
self.warmup_epochs = warmup_epochs
assert scheduler_interval in ["step", "epoch"]
self.scheduler_interval = scheduler_interval
self.lr_decay_steps = lr_decay_steps
self.no_channel_last = no_channel_last

@@ -117,6 +122,12 @@ def __init__(
for param in self.backbone.parameters():
param.requires_grad = False

if scheduler_interval == "step":
warnings.warn(
f"Using scheduler_interval={scheduler_interval} might generate "
"issues when resuming a checkpoint."
)

# can provide up to ~20% speed up
if not no_channel_last:
self = self.to(memory_format=torch.channels_last)
@@ -166,6 +177,9 @@ def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
parser.add_argument("--min_lr", default=0.0, type=float)
parser.add_argument("--warmup_start_lr", default=0.003, type=float)
parser.add_argument("--warmup_epochs", default=10, type=int)
parser.add_argument(
"--scheduler_interval", choices=["step", "epoch"], default="step", type=str
)

# disables channel last optimization
parser.add_argument("--no_channel_last", action="store_true")
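A hypothetical sketch of how the new flag travels from the command line into the module (the LinearModel call is schematic; only the argument itself is taken from the diff):

```python
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument(
    "--scheduler_interval", choices=["step", "epoch"], default="step", type=str
)
args = parser.parse_args(["--scheduler_interval", "epoch"])
print(args.scheduler_interval)  # "epoch" -> scheduler stepped once per epoch

# schematic, assuming the usual kwargs-style construction:
# model = LinearModel(backbone, scheduler_interval=args.scheduler_interval, **other_kwargs)
```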
@@ -211,24 +225,6 @@ def num_training_steps(self) -> int:

return self._num_training_steps

def forward(self, X: torch.tensor) -> Dict[str, Any]:
"""Performs forward pass of the frozen backbone and the linear layer for evaluation.
Args:
X (torch.tensor): a batch of images in the tensor format.
Returns:
Dict[str, Any]: a dict containing features and logits.
"""

if not self.no_channel_last:
X = X.to(memory_format=torch.channels_last)

with torch.no_grad():
feats = self.backbone(X)
logits = self.classifier(feats)
return {"logits": logits, "feats": feats}

def configure_optimizers(self) -> Tuple[List, List]:
"""Configures the optimizer for the linear layer.
@@ -256,15 +252,25 @@ def configure_optimizers(self) -> Tuple[List, List]:
return optimizer

if self.scheduler == "warmup_cosine":
max_warmup_steps = (
self.warmup_epochs * self.num_training_steps
if self.scheduler_interval == "step"
else self.warmup_epochs
)
max_scheduler_steps = (
self.max_epochs * self.num_training_steps
if self.scheduler_interval == "step"
else self.max_epochs
)
scheduler = {
"scheduler": LinearWarmupCosineAnnealingLR(
optimizer,
warmup_epochs=self.warmup_epochs * self.num_training_steps,
max_epochs=self.max_epochs * self.num_training_steps,
warmup_epochs=max_warmup_steps,
max_epochs=max_scheduler_steps,
warmup_start_lr=self.warmup_start_lr if self.warmup_epochs > 0 else self.lr,
eta_min=self.min_lr,
),
"interval": "step",
"interval": self.scheduler_interval,
"frequency": 1,
}
elif self.scheduler == "reduce":
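The "interval" key is what PyTorch Lightning reads to decide when to call scheduler.step(): after every optimizer step for "step", at the end of each training epoch for "epoch", with "frequency" controlling how many of those boundaries pass between calls. A hedged sketch of the two dicts this branch can now return (the scheduler instance is elided):

```python
# Illustrative only; the actual scheduler object is elided with `...`.
scheduler_when_stepped_per_batch = {
    "scheduler": ...,     # horizons given in optimizer steps
    "interval": "step",   # Lightning calls scheduler.step() after every optimizer step
    "frequency": 1,
}
scheduler_when_stepped_per_epoch = {
    "scheduler": ...,     # horizons left in epochs
    "interval": "epoch",  # Lightning calls scheduler.step() once per epoch
    "frequency": 1,
}
```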
@@ -280,6 +286,24 @@

return [optimizer], [scheduler]

def forward(self, X: torch.tensor) -> Dict[str, Any]:
"""Performs forward pass of the frozen backbone and the linear layer for evaluation.
Args:
X (torch.tensor): a batch of images in the tensor format.
Returns:
Dict[str, Any]: a dict containing features and logits.
"""

if not self.no_channel_last:
X = X.to(memory_format=torch.channels_last)

with torch.no_grad():
feats = self.backbone(X)
logits = self.classifier(feats)
return {"logits": logits, "feats": feats}

def shared_step(
self, batch: Tuple, batch_idx: int
) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]:
1 change: 0 additions & 1 deletion tests/args/test_setup.py
@@ -21,7 +21,6 @@
import os
import subprocess
import textwrap
from pathlib import Path

from solo.args.utils import additional_setup_linear, additional_setup_pretrain
from tests.dali.utils import DummyDataset