
Commit bc81dbf
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Dec 31, 2022
1 parent 94e01eb commit bc81dbf
Showing 6 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion solo/backbones/vit/vit_mae.py
@@ -88,7 +88,7 @@ def initialize_weights(self):
         # initialization
         # initialize (and freeze) pos_embed by sin-cos embedding
         pos_embed = generate_2d_sincos_pos_embed(
-            self.pos_embed.shape[-1], int(self.patch_embed.num_patches ** 0.5), cls_token=True
+            self.pos_embed.shape[-1], int(self.patch_embed.num_patches**0.5), cls_token=True
         )
         self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

2 changes: 1 addition & 1 deletion solo/backbones/vit/vit_mocov3.py
@@ -72,7 +72,7 @@ def build_2d_sincos_position_embedding(self, temperature=10000.0):
         ), "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
         pos_dim = self.embed_dim // 4
         omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
-        omega = 1.0 / (temperature ** omega)
+        omega = 1.0 / (temperature**omega)
         out_w = torch.einsum("m,d->md", [grid_w.flatten(), omega])
         out_h = torch.einsum("m,d->md", [grid_h.flatten(), omega])
         pos_emb = torch.cat(
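For context, this hunk shows only the middle of build_2d_sincos_position_embedding: a geometric ladder of frequencies (omega) crossed via einsum with the flattened patch-grid coordinates. A self-contained sketch of the full scheme follows; the grid_h/grid_w construction and the final sin/cos concatenation are assumptions based on the standard MoCo v3 recipe, since they fall outside the hunk.

import torch

def sincos_pos_embed_2d(embed_dim: int, h: int, w: int, temperature: float = 10000.0) -> torch.Tensor:
    """Sketch: 2D sin-cos position embedding of shape (h * w, embed_dim)."""
    assert embed_dim % 4 == 0, "embed_dim must be divisible by 4"
    # assumed grid setup (not shown in the hunk)
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
    pos_dim = embed_dim // 4
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1.0 / (temperature**omega)  # geometric frequency ladder, as in the hunk
    out_w = torch.einsum("m,d->md", grid_w.flatten(), omega)  # (h*w, pos_dim)
    out_h = torch.einsum("m,d->md", grid_h.flatten(), omega)  # (h*w, pos_dim)
    # assumed ending: each axis contributes a sin half and a cos half,
    # so 4 * pos_dim == embed_dim -- hence the divisibility assert above.
    return torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)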
2 changes: 1 addition & 1 deletion solo/losses/mae.py
@@ -37,7 +37,7 @@ def patchify(imgs: torch.Tensor, patch_size: int) -> torch.Tensor:
     h = w = imgs.size(2) // patch_size
     x = imgs.reshape(shape=(imgs.size(0), 3, h, patch_size, w, patch_size))
     x = torch.einsum("nchpwq->nhwpqc", x)
-    x = x.reshape(shape=(imgs.size(0), h * w, patch_size ** 2 * 3))
+    x = x.reshape(shape=(imgs.size(0), h * w, patch_size**2 * 3))
     return x


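As a quick shape check, the hunk above happens to contain the whole body of patchify, which turns an image batch into flattened per-patch pixel vectors. Reproducing it verbatim in a self-contained sketch (the toy sizes below are arbitrary illustration values):

import torch

def patchify(imgs: torch.Tensor, patch_size: int) -> torch.Tensor:
    # same logic as the hunk above
    h = w = imgs.size(2) // patch_size
    x = imgs.reshape(shape=(imgs.size(0), 3, h, patch_size, w, patch_size))
    x = torch.einsum("nchpwq->nhwpqc", x)
    x = x.reshape(shape=(imgs.size(0), h * w, patch_size**2 * 3))
    return x

imgs = torch.randn(2, 3, 32, 32)      # batch of 2 RGB 32x32 images
patches = patchify(imgs, patch_size=8)
print(patches.shape)                  # torch.Size([2, 16, 192]): 4*4 patches, 8*8*3 values each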
2 changes: 1 addition & 1 deletion solo/losses/vibcreg.py
@@ -43,7 +43,7 @@ def covariance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
     fxf_cov_z2 = torch.mm(norm_z2.T, norm_z2)
     fxf_cov_z1.fill_diagonal_(0.0)
     fxf_cov_z2.fill_diagonal_(0.0)
-    cov_loss = (fxf_cov_z1 ** 2).mean() + (fxf_cov_z2 ** 2).mean()
+    cov_loss = (fxf_cov_z1**2).mean() + (fxf_cov_z2**2).mean()
     return cov_loss


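The hunk shows only the tail of covariance_loss: build a feature-by-feature covariance matrix, zero its diagonal, and penalize the squared off-diagonal entries so feature dimensions decorrelate. In the self-contained sketch below, the preprocessing of norm_z1/norm_z2 (centering each feature dimension, then L2-normalizing along the batch axis) is an assumption in the spirit of VICReg-style losses, not something this diff confirms:

import torch
import torch.nn.functional as F

def covariance_loss_sketch(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    """Sketch: off-diagonal covariance penalty over (batch, feature) embeddings."""
    # assumed preprocessing -- one plausible reading of "norm_z" in the hunk
    norm_z1 = F.normalize(z1 - z1.mean(dim=0), p=2, dim=0)
    norm_z2 = F.normalize(z2 - z2.mean(dim=0), p=2, dim=0)
    fxf_cov_z1 = torch.mm(norm_z1.T, norm_z1)  # (feature, feature)
    fxf_cov_z2 = torch.mm(norm_z2.T, norm_z2)
    fxf_cov_z1.fill_diagonal_(0.0)             # leave per-feature variance unpenalized
    fxf_cov_z2.fill_diagonal_(0.0)
    return (fxf_cov_z1**2).mean() + (fxf_cov_z2**2).mean()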
4 changes: 2 additions & 2 deletions solo/methods/mae.py
@@ -59,7 +59,7 @@ def __init__(
         )

         self.decoder_norm = nn.LayerNorm(embed_dim)
-        self.decoder_pred = nn.Linear(embed_dim, patch_size ** 2 * 3, bias=True)
+        self.decoder_pred = nn.Linear(embed_dim, patch_size**2 * 3, bias=True)

         # init all weights according to MAE's repo
         self.initialize_weights()
@@ -70,7 +70,7 @@ def initialize_weights(self):

         decoder_pos_embed = generate_2d_sincos_pos_embed(
             self.decoder_pos_embed.shape[-1],
-            int(self.num_patches ** 0.5),
+            int(self.num_patches**0.5),
             cls_token=True,
         )
         self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
2 changes: 1 addition & 1 deletion solo/utils/misc.py
@@ -333,7 +333,7 @@ def generate_1d_sincos_pos_embed_from_grid(embed_dim, pos):
     assert embed_dim % 2 == 0
     omega = np.arange(embed_dim // 2, dtype=np.float)
     omega /= embed_dim / 2.0
-    omega = 1.0 / 10000 ** omega  # (D/2,)
+    omega = 1.0 / 10000**omega  # (D/2,)

     pos = pos.reshape(-1)  # (M,)
     out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
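This hunk cuts off just before the sin/cos halves are assembled. A runnable sketch that completes the helper with the standard MAE-style ending (the concatenation line is an assumption about code below the hunk) is given here; the sketch also uses plain float where the context line has np.float, since that alias was removed in NumPy 1.24 and would not run on current versions:

import numpy as np

def sincos_pos_embed_1d(embed_dim: int, pos: np.ndarray) -> np.ndarray:
    """Sketch: 1D sin-cos embedding of shape (M, embed_dim) for M positions."""
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)  # np.float removed in NumPy >= 1.24
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega              # (D/2,)
    pos = pos.reshape(-1)                   # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
    # assumed ending, following the reference MAE implementation:
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (M, D)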
