-
Notifications
You must be signed in to change notification settings - Fork 186
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
311 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
defaults: | ||
- _self_ | ||
- augmentations: symmetric.yaml | ||
- wandb: private.yaml | ||
- override hydra/hydra_logging: disabled | ||
- override hydra/job_logging: disabled | ||
|
||
# disable hydra outputs | ||
hydra: | ||
output_subdir: null | ||
run: | ||
dir: . | ||
|
||
name: "frossl-cifar10" # change here for cifar100 | ||
method: "frossl" | ||
backbone: | ||
name: "resnet18" | ||
method_kwargs: | ||
proj_hidden_dim: 2048 | ||
proj_output_dim: 1024 | ||
invariance_weight: 1.4 | ||
|
||
data: | ||
dataset: cifar10 # change here for cifar100 | ||
train_path: "./datasets" | ||
val_path: "./datasets" | ||
format: "image_folder" | ||
num_workers: 8 | ||
optimizer: | ||
name: "lars" | ||
batch_size: 256 | ||
lr: 0.3 | ||
classifier_lr: 0.1 | ||
weight_decay: 1e-4 | ||
kwargs: | ||
clip_lr: True | ||
eta: 0.02 | ||
exclude_bias_n_norm: True | ||
scheduler: | ||
name: "warmup_cosine" | ||
checkpoint: | ||
enabled: True | ||
dir: "trained_models" | ||
frequency: 1 | ||
auto_resume: | ||
enabled: True | ||
|
||
# overwrite PL stuff | ||
max_epochs: 1000 | ||
devices: [0] | ||
sync_batchnorm: True | ||
accelerator: "gpu" | ||
strategy: "ddp" | ||
precision: 16-mixed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
defaults: | ||
- _self_ | ||
- augmentations: symmetric.yaml | ||
- wandb: private.yaml | ||
- override hydra/hydra_logging: disabled | ||
- override hydra/job_logging: disabled | ||
|
||
# disable hydra outputs | ||
hydra: | ||
output_subdir: null | ||
run: | ||
dir: . | ||
|
||
name: "frossl-imagenet100" | ||
method: "frossl" | ||
backbone: | ||
name: "resnet18" | ||
method_kwargs: | ||
proj_hidden_dim: 2048 | ||
proj_output_dim: 1024 | ||
invariance_weight: 2.0 | ||
data: | ||
dataset: imagenet100 | ||
train_path: "./datasets/imagenet100/train" | ||
val_path: "./datasets/imagenet100/val" | ||
format: "dali" | ||
num_workers: 16 | ||
optimizer: | ||
name: "lars" | ||
batch_size: 256 | ||
lr: 0.3 | ||
classifier_lr: 0.1 | ||
weight_decay: 1e-4 | ||
kwargs: | ||
clip_lr: True | ||
eta: 0.02 | ||
exclude_bias_n_norm: True | ||
scheduler: | ||
name: "warmup_cosine" | ||
checkpoint: | ||
enabled: True | ||
dir: "trained_models" | ||
frequency: 1 | ||
auto_resume: | ||
enabled: True | ||
|
||
# overwrite PL stuff | ||
max_epochs: 400 | ||
devices: [0, 1] | ||
sync_batchnorm: True | ||
accelerator: "gpu" | ||
strategy: "ddp" | ||
precision: 16-mixed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
defaults: | ||
- _self_ | ||
- augmentations: vicreg.yaml | ||
- wandb: private.yaml | ||
- override hydra/hydra_logging: disabled | ||
- override hydra/job_logging: disabled | ||
|
||
# disable hydra outputs | ||
hydra: | ||
output_subdir: null | ||
run: | ||
dir: . | ||
|
||
name: "frossl-imagenet" | ||
method: "frossl" | ||
backbone: | ||
name: "resnet18" | ||
method_kwargs: | ||
proj_hidden_dim: 2048 | ||
proj_output_dim: 1024 | ||
invariance_weight: 2.0 | ||
|
||
data: | ||
dataset: imagenet | ||
train_path: "./datasets/imagenet/train" | ||
val_path: "./datasets/imagenet/val" | ||
format: "dali" | ||
num_workers: 8 | ||
optimizer: | ||
name: "lars" | ||
batch_size: 256 | ||
lr: 0.3 | ||
classifier_lr: 0.1 | ||
weight_decay: 1e-4 | ||
kwargs: | ||
clip_lr: True | ||
eta: 0.02 | ||
exclude_bias_n_norm: True | ||
scheduler: | ||
name: "warmup_cosine" | ||
checkpoint: | ||
enabled: True | ||
dir: "trained_models" | ||
frequency: 1 | ||
auto_resume: | ||
enabled: True | ||
|
||
# overwrite PL stuff | ||
max_epochs: 100 | ||
devices: [0, 1] | ||
sync_batchnorm: True | ||
accelerator: "gpu" | ||
strategy: "ddp" | ||
precision: 16-mixed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from typing import Any, List, Sequence, Dict | ||
import torch | ||
import torch.distributed as dist | ||
import torch.nn.functional as F | ||
|
||
def frossl_loss_func(
    z: torch.Tensor, invariance_weight: float = 1
) -> torch.Tensor:
    """Computes FroSSL's loss given a batch of projected features from several views.

    The loss is the sum of two terms:
        * a regularization term: for every view, the log of the squared Frobenius
          norm of the trace-normalized covariance (or Gram) matrix of the
          L2-normalized embeddings, summed over views;
        * an invariance term: the mean squared error between every view and the
          average of all views, scaled by ``invariance_weight``.

    Args:
        z (torch.Tensor): V x N x D tensor of projected features, where V is the
            number of views, N the number of samples and D the feature dimension.
            ``z[v, n]`` is the embedding of view ``v`` of sample ``n``.
        invariance_weight (float): weight for the invariance loss term. Defaults to 1.

    Raises:
        ValueError: if ``z`` is not a 3-dimensional tensor.

    Returns:
        torch.Tensor: scalar FroSSL loss.
    """
    if z.dim() != 3:
        raise ValueError(
            f"z must have shape (views, N, D), got a tensor with {z.dim()} dimension(s)"
        )
    _, N, D = z.shape

    z = F.normalize(z, dim=-1)  # V x N x D, every embedding has unit L2 norm

    # Z^T Z (D x D) and Z Z^T (N x N) share their non-zero eigenvalues, hence the
    # same trace and Frobenius norm, so the smaller of the two is built per view.
    z_t = z.transpose(-2, -1)  # V x D x N
    if N > D:
        cov = z_t @ z  # V x D x D
    else:
        cov = z @ z_t  # V x N x N

    # normalize every view's matrix to unit trace (batched equivalent of torch.trace)
    trace_per_view = cov.diagonal(dim1=-2, dim2=-1).sum(dim=-1)  # V
    cov = cov / trace_per_view[:, None, None]

    # sum over views of log(||cov||_F ** 2); the square is brought outside the log.
    # The term is *added* to the loss: minimizing the Frobenius norm of a unit-trace
    # matrix spreads its eigenvalues, while a collapsed embedding maximizes it.
    fro_norm_per_view = torch.linalg.matrix_norm(cov, ord="fro")  # V
    regularization_term = torch.sum(2 * torch.log(fro_norm_per_view))

    # align each view to the average view
    average_z = z.mean(dim=0, keepdim=True)  # 1 x N x D, averaged across views
    invariance_loss_term = F.mse_loss(z, average_z.expand_as(z))

    total_loss = regularization_term + invariance_weight * invariance_loss_term
    return total_loss
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
from typing import Any, List, Sequence, Dict | ||
|
||
import omegaconf | ||
import torch | ||
import torch.nn as nn | ||
from solo.methods.base import BaseMethod | ||
from solo.utils.misc import omegaconf_select, gather | ||
from solo.losses.frossl import frossl_loss_func | ||
|
||
class FroSSL(BaseMethod):
    def __init__(self, cfg: omegaconf.DictConfig):
        """Implements FroSSL (https://arxiv.org/pdf/2310.02903)

        Extra cfg settings:
            method_kwargs:
                proj_hidden_dim (int): number of neurons of the hidden layers of the projector.
                proj_output_dim (int): number of dimensions of projected features.
                invariance_weight (float): weight of the invariance loss term.
        """

        super().__init__(cfg)

        self.invariance_weight: float = cfg.method_kwargs.invariance_weight

        proj_hidden_dim: int = cfg.method_kwargs.proj_hidden_dim
        proj_output_dim: int = cfg.method_kwargs.proj_output_dim

        # projector: 3-layer MLP mapping backbone features to the embedding space
        self.projector = nn.Sequential(
            nn.Linear(self.features_dim, proj_hidden_dim),
            nn.BatchNorm1d(proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, proj_hidden_dim),
            nn.BatchNorm1d(proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, proj_output_dim),
        )

    @staticmethod
    def add_and_assert_specific_cfg(cfg: omegaconf.DictConfig) -> omegaconf.DictConfig:
        """Adds method specific default values/checks for config.

        ``proj_hidden_dim`` and ``proj_output_dim`` are mandatory;
        ``invariance_weight`` falls back to 1.0 when it is not set.

        Args:
            cfg (omegaconf.DictConfig): DictConfig object.

        Returns:
            omegaconf.DictConfig: same as the argument, used to avoid errors.
        """

        cfg = super(FroSSL, FroSSL).add_and_assert_specific_cfg(cfg)

        assert not omegaconf.OmegaConf.is_missing(cfg, "method_kwargs.proj_hidden_dim")
        assert not omegaconf.OmegaConf.is_missing(cfg, "method_kwargs.proj_output_dim")

        cfg.method_kwargs.invariance_weight = omegaconf_select(
            cfg, "method_kwargs.invariance_weight", 1.0
        )
        return cfg

    @property
    def learnable_params(self) -> List[dict]:
        """Adds projector parameters to parent's learnable parameters.

        Returns:
            List[dict]: list of learnable parameters.
        """

        extra_learnable_params = [{"name": "projector", "params": self.projector.parameters()}]
        return super().learnable_params + extra_learnable_params

    def forward(self, X: torch.Tensor) -> Dict[str, Any]:
        """Performs the forward pass of the backbone and the projector.

        Args:
            X (torch.Tensor): a batch of images in the tensor format.

        Returns:
            Dict[str, Any]: a dict containing the outputs of the parent and the projected
                features (under the key ``"z"``).
        """

        out = super().forward(X)
        z = self.projector(out["feats"])
        out.update({"z": z})
        return out

    def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
        """Training step for FroSSL reusing BaseMethod training step.

        Args:
            batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
                [X] is a list of size num_crops containing batches of images.
            batch_idx (int): index of the batch.

        Returns:
            torch.Tensor: total loss composed of FroSSL loss and classification loss.
        """

        out = super().training_step(batch, batch_idx)
        class_loss = out["loss"]

        # out["z"] holds one projection tensor per view; stack them on a new leading axis
        z = torch.stack(out["z"], dim=0)  # V x N_per_gpu x D

        # Collect the samples of every process along the batch axis. This must be the
        # distributed `gather` helper imported from solo.utils.misc: `torch.gather` is an
        # index-selection op that requires an `index` argument and would raise here.
        z = gather(z, dim=1)  # V x N_total x D

        frossl_loss = frossl_loss_func(z, invariance_weight=self.invariance_weight)
        self.log("train_frossl_loss", frossl_loss, on_epoch=True, sync_dist=True)

        return frossl_loss + class_loss