Commit b67c7eb: Added FroSSL
OFSkean committed Aug 4, 2024 (1 parent: b69b4bd)
Showing 8 changed files with 311 additions and 1 deletion.
2 changes: 1 addition & 1 deletion main_pretrain.py
@@ -69,7 +69,7 @@ def main(cfg: DictConfig):
assert cfg.method in METHODS, f"Choose from {METHODS.keys()}"

if cfg.data.num_large_crops != 2:
assert cfg.method in ["wmse", "mae"]
assert cfg.method in ["wmse", "mae", "frossl"]

model = METHODS[cfg.method](cfg)
make_contiguous(model)
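This relaxation matches the loss added below: frossl_loss_func is written over an arbitrary number of views V, so FroSSL, like W-MSE and MAE, is not restricted to exactly two large crops.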
54 changes: 54 additions & 0 deletions scripts/pretrain/cifar/frossl.yaml
@@ -0,0 +1,54 @@
defaults:
- _self_
- augmentations: symmetric.yaml
- wandb: private.yaml
- override hydra/hydra_logging: disabled
- override hydra/job_logging: disabled

# disable hydra outputs
hydra:
output_subdir: null
run:
dir: .

name: "frossl-cifar10" # change here for cifar100
method: "frossl"
backbone:
name: "resnet18"
method_kwargs:
proj_hidden_dim: 2048
proj_output_dim: 1024
invariance_weight: 1.4

data:
dataset: cifar10 # change here for cifar100
train_path: "./datasets"
val_path: "./datasets"
format: "image_folder"
num_workers: 8
optimizer:
name: "lars"
batch_size: 256
lr: 0.3
classifier_lr: 0.1
weight_decay: 1e-4
kwargs:
clip_lr: True
eta: 0.02
exclude_bias_n_norm: True
scheduler:
name: "warmup_cosine"
checkpoint:
enabled: True
dir: "trained_models"
frequency: 1
auto_resume:
enabled: True

# overwrite PL stuff
max_epochs: 1000
devices: [0]
sync_batchnorm: True
accelerator: "gpu"
strategy: "ddp"
precision: 16-mixed
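Configs like this one are launched through solo-learn's Hydra entry point, following the same pattern the repository documents for its other methods (exact paths depend on your checkout): python3 main_pretrain.py --config-path scripts/pretrain/cifar/ --config-name frossl.yaml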
53 changes: 53 additions & 0 deletions scripts/pretrain/imagenet-100/frossl.yaml
@@ -0,0 +1,53 @@
defaults:
- _self_
- augmentations: symmetric.yaml
- wandb: private.yaml
- override hydra/hydra_logging: disabled
- override hydra/job_logging: disabled

# disable hydra outputs
hydra:
output_subdir: null
run:
dir: .

name: "frossl-imagenet100"
method: "frossl"
backbone:
name: "resnet18"
method_kwargs:
proj_hidden_dim: 2048
proj_output_dim: 1024
invariance_weight: 2.0
data:
dataset: imagenet100
train_path: "./datasets/imagenet100/train"
val_path: "./datasets/imagenet100/val"
format: "dali"
num_workers: 16
optimizer:
name: "lars"
batch_size: 256
lr: 0.3
classifier_lr: 0.1
weight_decay: 1e-4
kwargs:
clip_lr: True
eta: 0.02
exclude_bias_n_norm: True
scheduler:
name: "warmup_cosine"
checkpoint:
enabled: True
dir: "trained_models"
frequency: 1
auto_resume:
enabled: True

# overwrite PL stuff
max_epochs: 400
devices: [0, 1]
sync_batchnorm: True
accelerator: "gpu"
strategy: "ddp"
precision: 16-mixed
54 changes: 54 additions & 0 deletions scripts/pretrain/imagenet/frossl.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
defaults:
- _self_
- augmentations: vicreg.yaml
- wandb: private.yaml
- override hydra/hydra_logging: disabled
- override hydra/job_logging: disabled

# disable hydra outputs
hydra:
output_subdir: null
run:
dir: .

name: "frossl-imagenet"
method: "frossl"
backbone:
name: "resnet18"
method_kwargs:
proj_hidden_dim: 2048
proj_output_dim: 1024
invariance_weight: 2.0

data:
dataset: imagenet
train_path: "./datasets/imagenet/train"
val_path: "./datasets/imagenet/val"
format: "dali"
num_workers: 8
optimizer:
name: "lars"
batch_size: 256
lr: 0.3
classifier_lr: 0.1
weight_decay: 1e-4
kwargs:
clip_lr: True
eta: 0.02
exclude_bias_n_norm: True
scheduler:
name: "warmup_cosine"
checkpoint:
enabled: True
dir: "trained_models"
frequency: 1
auto_resume:
enabled: True

# overwrite PL stuff
max_epochs: 100
devices: [0, 1]
sync_batchnorm: True
accelerator: "gpu"
strategy: "ddp"
precision: 16-mixed
2 changes: 2 additions & 0 deletions solo/losses/__init__.py
@@ -21,6 +21,7 @@
from solo.losses.byol import byol_loss_func
from solo.losses.deepclusterv2 import deepclusterv2_loss_func
from solo.losses.dino import DINOLoss
+from solo.losses.frossl import frossl_loss_func
from solo.losses.mae import mae_loss_func
from solo.losses.mocov2plus import mocov2plus_loss_func
from solo.losses.mocov3 import mocov3_loss_func
@@ -38,6 +39,7 @@
"byol_loss_func",
"deepclusterv2_loss_func",
"DINOLoss",
"frossl_loss_func",
"mae_loss_func",
"mocov2plus_loss_func",
"mocov3_loss_func",
39 changes: 39 additions & 0 deletions solo/losses/frossl.py
@@ -0,0 +1,39 @@
import torch
import torch.nn.functional as F

def frossl_loss_func(
z: torch.Tensor, invariance_weight=1
) -> torch.Tensor:
"""Computes FroSSL's loss given batch of projected features z
from num_crops different views.
Args:
        z (torch.Tensor): V x N x D tensor of projected features, where z[v] holds the
            N projected samples of view v (so z[:, n] are the V views of image n).
        invariance_weight (float): weight for the invariance loss term. Default is 1.
    Returns:
torch.Tensor: FroSSL loss.
"""
    V, N, D = z.shape

    z = F.normalize(z, dim=-1)  # V x N x D

    # per-view covariance; use whichever Gram matrix is smaller
    if N > D:
        cov = z.transpose(1, 2) @ z  # V x D x D
    else:
        cov = z @ z.transpose(1, 2)  # V x N x N

    # normalize each view's covariance by its trace
    trace = torch.diagonal(cov, dim1=-2, dim2=-1).sum(-1)  # V
    cov = cov / trace.view(V, 1, 1)

    # sum the log-Frobenius norm of each view's covariance matrix
    fro_norm_per_view = torch.linalg.norm(cov, ord="fro", dim=(-2, -1))  # V
    regularization_term = torch.sum(2 * torch.log(fro_norm_per_view))  # bring Frobenius square outside the log

    # align each view to the average view
    average_z = torch.mean(z, dim=0)  # N x D, samples are averaged across views
    invariance_loss_term = F.mse_loss(z, average_z.repeat(V, 1, 1))

    total_loss = regularization_term + invariance_weight * invariance_loss_term
    return total_loss
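A minimal usage sketch for the loss (not part of the commit; the shapes and the invariance_weight value are illustrative):

import torch
from solo.losses.frossl import frossl_loss_func

V, N, D = 2, 256, 1024                 # views, batch size, projector output dim
z = torch.randn(V, N, D)               # stand-in for projected features
loss = frossl_loss_func(z, invariance_weight=1.4)
print(loss.item())                     # a single scalar combining both terms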
2 changes: 2 additions & 0 deletions solo/methods/__init__.py
@@ -22,6 +22,7 @@
from solo.methods.byol import BYOL
from solo.methods.deepclusterv2 import DeepClusterV2
from solo.methods.dino import DINO
+from solo.methods.frossl import FroSSL
from solo.methods.linear import LinearModel
from solo.methods.mae import MAE
from solo.methods.mocov2plus import MoCoV2Plus
@@ -49,6 +50,7 @@
"byol": BYOL,
"deepclusterv2": DeepClusterV2,
"dino": DINO,
"frossl": FroSSL,
"mae": MAE,
"mocov2plus": MoCoV2Plus,
"mocov3": MoCoV3,
106 changes: 106 additions & 0 deletions solo/methods/frossl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import Any, List, Sequence, Dict

import omegaconf
import torch
import torch.nn as nn
from solo.methods.base import BaseMethod
from solo.utils.misc import omegaconf_select, gather
from solo.losses.frossl import frossl_loss_func

class FroSSL(BaseMethod):
def __init__(self, cfg: omegaconf.DictConfig):
"""Implements FroSSL (https://arxiv.org/pdf/2310.02903)
Extra cfg settings:
method_kwargs:
proj_hidden_dim (int): number of neurons of the hidden layers of the projector.
proj_output_dim (int): number of dimensions of projected features.
invariance_weight (float): weight of the invariance loss term.
"""

super().__init__(cfg)

self.invariance_weight: float = cfg.method_kwargs.invariance_weight

proj_hidden_dim: int = cfg.method_kwargs.proj_hidden_dim
proj_output_dim: int = cfg.method_kwargs.proj_output_dim

# projector
self.projector = nn.Sequential(
nn.Linear(self.features_dim, proj_hidden_dim),
nn.BatchNorm1d(proj_hidden_dim),
nn.ReLU(),
nn.Linear(proj_hidden_dim, proj_hidden_dim),
nn.BatchNorm1d(proj_hidden_dim),
nn.ReLU(),
nn.Linear(proj_hidden_dim, proj_output_dim),
)

@staticmethod
def add_and_assert_specific_cfg(cfg: omegaconf.DictConfig) -> omegaconf.DictConfig:
"""Adds method specific default values/checks for config.
Args:
cfg (omegaconf.DictConfig): DictConfig object.
Returns:
omegaconf.DictConfig: same as the argument, used to avoid errors.
"""

cfg = super(FroSSL, FroSSL).add_and_assert_specific_cfg(cfg)

assert not omegaconf.OmegaConf.is_missing(cfg, "method_kwargs.proj_hidden_dim")
assert not omegaconf.OmegaConf.is_missing(cfg, "method_kwargs.proj_output_dim")

cfg.method_kwargs.invariance_weight = omegaconf_select(cfg, "method_kwargs.invariance_weight", 1.0)
return cfg

@property
def learnable_params(self) -> List[dict]:
"""Adds projector parameters to parent's learnable parameters.
Returns:
List[dict]: list of learnable parameters.
"""

extra_learnable_params = [{"name": "projector", "params": self.projector.parameters()}]
return super().learnable_params + extra_learnable_params

def forward(self, X):
"""Performs the forward pass of the backbone and the projector.
Args:
X (torch.Tensor): a batch of images in the tensor format.
Returns:
Dict[str, Any]: a dict containing the outputs of the parent and the projected features.
"""

out = super().forward(X)
z = self.projector(out["feats"])
out.update({"z": z})
return out

def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
"""Training step for FroSSL reusing BaseMethod training step.
Args:
batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
[X] is a list of size num_crops containing batches of images.
batch_idx (int): index of the batch.
Returns:
torch.Tensor: total loss composed of FroSSL loss and classification loss.
"""

out = super().training_step(batch, batch_idx)
class_loss = out["loss"]

z = torch.stack(out["z"], dim=0) # V x N_per_gpu x D
        z = gather(z, dim=1)  # all-gather across processes -> V x N_total x D (solo.utils.misc.gather; torch.gather would need an index tensor)

frossl_loss = frossl_loss_func(z, invariance_weight=self.invariance_weight)
self.log("train_frossl_loss", frossl_loss, on_epoch=True, sync_dist=True)

return frossl_loss + class_loss
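As a quick sanity check (not part of the commit), the trace-normalized Frobenius term used in the loss is smallest when embeddings are spread out and largest at collapse, which is why it enters the loss with a positive sign:

import torch
import torch.nn.functional as F

def fro_term(z):  # z: N x D, a single view
    z = F.normalize(z, dim=-1)
    cov = z @ z.T
    cov = cov / torch.trace(cov)
    return 2 * torch.log(torch.linalg.norm(cov, ord="fro"))

spread = torch.randn(256, 64)                  # roughly decorrelated embeddings
collapsed = torch.randn(1, 64).repeat(256, 1)  # every sample identical
assert fro_term(spread) < fro_term(collapsed)  # collapse maximizes the term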
