diff_traj: freeze pretrained weights
d4l3k committed Jul 7, 2024
1 parent 2accf68 commit a929c9e
Showing 3 changed files with 39 additions and 4 deletions.
4 changes: 2 additions & 2 deletions torchdrive/train_config.py
@@ -333,8 +333,8 @@ def create_model(
             num_frames=self.num_frames,
         ).to(device)
 
-        # for cam_encoder in model.encoders.values():
-        #     freeze(cam_encoder.encoder)
+        for cam_encoder in model.encoders.values():
+            cam_encoder.freeze_pretrained_weights()
 
         return model

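Since the freeze works by setting requires_grad = False on the encoder parameters, the rest of the training setup does not have to change; any optimizer simply stops updating those weights. A minimal sketch of how a training script might skip frozen parameters when building the optimizer (the optimizer type and learning rate are illustrative, not taken from this repo, and model is assumed to be the object returned by create_model):

    import torch

    # Hand only the still-trainable parameters to the optimizer so the frozen
    # ViT backbone weights receive no updates and carry no optimizer state.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)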
25 changes: 25 additions & 0 deletions torchworld/models/test_vit.py
@@ -0,0 +1,25 @@
import unittest

import torch
from torchworld.models.vit import MaskViT


class TestViT(unittest.TestCase):
    def test_vit_mask(self):
        m = MaskViT(
            attention_dropout=0.1,
            cam_shape=(48, 64),
            dim=16,
            weights=None,
        )
        x = torch.rand(1, 3, 48, 64)
        mask = torch.ones(3, 4, dtype=torch.bool)

        features, out = m(x, mask)
        self.assertEqual(features.shape, (1, 768, 3, 4))
        self.assertEqual(out.shape, (1, 3 * 4, 16))

        m.freeze_pretrained_weights()
        needs_grad = [param for param in m.parameters() if param.requires_grad]
        # positional embedding + linear weight/bias
        self.assertEqual(len(needs_grad), 3)
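The three parameters the test expects to stay trainable should be, going by the vit.py changes below, the re-created positional embedding plus the weight and bias of the project linear layer. A quick sanity check by name (plain PyTorch, nothing repo-specific; the exact names are an assumption based on the attribute layout in MaskViT):

    m.freeze_pretrained_weights()
    trainable = [name for name, param in m.named_parameters() if param.requires_grad]
    # expected to look roughly like:
    #   ['encoder.encoder.pos_embedding', 'project.weight', 'project.bias']
    print(trainable)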
14 changes: 12 additions & 2 deletions torchworld/models/vit.py
@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 from torch import nn
@@ -11,12 +11,13 @@ def __init__(
         attention_dropout: float,
         cam_shape: Tuple[int, int],
         dim: int,
+        weights: Optional[object] = ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1,
     ) -> None:
         super().__init__()
 
         self.cam_shape = cam_shape
         self.encoder = vit_b_16(
-            weights=ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1,
+            weights=weights,
             progress=True,
             attention_dropout=attention_dropout,
         )
@@ -30,6 +31,13 @@ def __init__(
         )
         self.project = nn.Linear(self.encoder.hidden_dim, dim)
 
+    def freeze_pretrained_weights(self) -> None:
+        for param in self.encoder.parameters():
+            # skip pos embedding since we overwrote it
+            if param is self.encoder.encoder.pos_embedding:
+                continue
+            param.requires_grad = False
+
     def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
         # Reshape and permute the input tensor
         n, c, h, w = x.shape
@@ -44,6 +52,8 @@ def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
 
         unmasked = x
 
+        print(x.shape, mask.shape)
+
         # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, mask.sum())
         x = x[:, :, mask]
 
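The masked selection in forward relies on boolean advanced indexing over the patch grid: indexing the last two dimensions with a 2-D boolean mask flattens them into a single dimension of length mask.sum(). A standalone sketch of that step, with shapes matching the test above (illustrative only, not code from the repo):

    import torch

    x = torch.rand(2, 768, 3, 4)               # (n, hidden_dim, n_h, n_w)
    mask = torch.ones(3, 4, dtype=torch.bool)
    mask[0, 0] = False                         # drop one patch position
    tokens = x[:, :, mask]                     # (n, hidden_dim, mask.sum()) -> (2, 768, 11)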
