forked from recursionpharma/maes_microscopy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
vit.py
309 lines (269 loc) · 9.65 KB
/
vit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
# © Recursion Pharmaceuticals 2024
import timm.models.vision_transformer as vit
import torch
def generate_2d_sincos_pos_embeddings(
    embedding_dim: int,
    length: int,
    scale: float = 10000.0,
    use_class_token: bool = True,
    num_modality: int = 1,
) -> torch.nn.Parameter:
    """
    Generate 2-dimensional sin/cosine positional embeddings.

    Parameters
    ----------
    embedding_dim : int
        embedding dimension used in vit; must be divisible by 4 because it is
        split evenly between sin/cos of the height and sin/cos of the width
    length : int
        number of tokens along height or width of image after patching (assuming square)
    scale : float
        scale for sin/cos functions
    use_class_token : bool
        True - prepend a zero vector to be added to class_token, False - no vector added
    num_modality : int
        number of times the length*length grid encoding is repeated along the
        token axis (one copy per channel modality); every copy is identical

    Returns
    -------
    positional_encoding : torch.nn.Parameter
        frozen (requires_grad=False) positional encoding to add to vit patch encodings
        [1, num_modality*length*length, embedding_dim] or
        [1, 1+num_modality*length*length, embedding_dim]
        (w/o or w/ cls_token)

    Raises
    ------
    ValueError
        if embedding_dim is not divisible by 4
    """
    if embedding_dim % 4 != 0:
        # Without this check the floor division below silently yields an
        # encoding narrower than embedding_dim.
        raise ValueError(
            f"embedding_dim must be divisible by 4, got {embedding_dim}"
        )

    linear_positions = torch.arange(length, dtype=torch.float32)
    height_mesh, width_mesh = torch.meshgrid(
        linear_positions, linear_positions, indexing="ij"
    )

    # accommodate h and w x cos and sin embeddings
    positional_dim = embedding_dim // 4
    positional_weights = (
        torch.arange(positional_dim, dtype=torch.float32) / positional_dim
    )
    positional_weights = 1.0 / (scale**positional_weights)

    height_weights = torch.outer(height_mesh.flatten(), positional_weights)
    width_weights = torch.outer(width_mesh.flatten(), positional_weights)

    # [None, :, :] adds the leading axis of size 1
    positional_encoding = torch.cat(
        [
            torch.sin(height_weights),
            torch.cos(height_weights),
            torch.sin(width_weights),
            torch.cos(width_weights),
        ],
        dim=1,
    )[None, :, :]

    # repeat positional encoding for multiple channel modalities
    positional_encoding = positional_encoding.repeat(1, num_modality, 1)

    if use_class_token:
        class_token = torch.zeros([1, 1, embedding_dim], dtype=torch.float32)
        positional_encoding = torch.cat([class_token, positional_encoding], dim=1)

    return torch.nn.Parameter(positional_encoding, requires_grad=False)
class ChannelAgnosticPatchEmbed(vit.PatchEmbed):  # type: ignore[misc]
    """Patch embedding that reuses one single-channel projection for every input channel."""

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        embed_dim: int,
        bias: bool = True,
    ) -> None:
        """Set up the parent patch embedding, then install the shared 1-channel projection.

        Parameters
        ----------
        img_size : int
            forwarded to the parent constructor as ``img_size``
        patch_size : int
            forwarded to the parent and used as both kernel size and stride of the projection
        embed_dim : int
            number of output channels of the projection
        bias : bool
            whether the projection convolution has a bias term
        """
        parent_kwargs = dict(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=1,  # only consumed by the parent's proj, which is replaced below
            embed_dim=embed_dim,
            norm_layer=None,
            flatten=False,
            bias=bias,
        )
        super().__init__(**parent_kwargs)

        # a channel-agnostic MAE shares a single projection across all channels
        self.proj = torch.nn.Conv2d(
            1,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project each channel of ``x`` separately and return the tokens as one sequence."""
        projected = []
        for chan_idx in range(x.shape[1]):
            # slicing (rather than indexing) keeps a channel axis of size 1 for the conv
            single_channel = x[:, chan_idx : chan_idx + 1]
            projected.append(self.proj(single_channel))

        # new axis 2 indexes the source channel
        stacked = torch.stack(projected, dim=2)

        # merge everything after the embedding axis, then move tokens ahead of
        # the embedding axis: BCMHW -> BNC
        tokens = stacked.flatten(2)
        return tokens.transpose(1, 2)
class ChannelAgnosticViT(vit.VisionTransformer):  # type: ignore[misc]
    """VisionTransformer variant that slices its positional embedding to the input's token count."""

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional embeddings (and the class token, if any) to the patch tokens ``x``.

        Rewrite of
        https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L586

        The difference from timm is that ``self.pos_embed`` is truncated along
        the token axis to the current sequence length before being added, so the
        number of input channels may vary at inference time.
        """
        # TODO: upgrade timm to get access to register tokens, then also prepend
        # self.reg_token.expand(x.shape[0], -1, -1) when it is not None
        batch_size = x.shape[0]
        prefix_tokens = (
            [self.cls_token.expand(batch_size, -1, -1)]
            if self.cls_token is not None
            else []
        )

        embed_before_prefix = self.no_embed_class

        if embed_before_prefix:
            # positions apply to patch tokens only; prefix tokens are added afterwards
            x = x + self.pos_embed[:, : x.shape[1]]

        if prefix_tokens:
            x = torch.cat([*prefix_tokens, x], dim=1)

        if not embed_before_prefix:
            # positions cover the prefix tokens as well, so add after concatenation
            x = x + self.pos_embed[:, : x.shape[1]]

        return self.pos_drop(x)  # type: ignore[no-any-return]
def channel_agnostic_vit(
    vit_backbone: vit.VisionTransformer, max_in_chans: int
) -> vit.VisionTransformer:
    """Convert a constructed ViT backbone, in place, into its channel-agnostic form.

    Parameters
    ----------
    vit_backbone : timm.models.vision_transformer.VisionTransformer
        the model to modify; its patch embedding, positional embedding and class are replaced
    max_in_chans : int
        passed as ``num_modality`` when generating the positional embedding, i.e. how
        many copies of the patch-grid encoding are laid out along the token axis

    Returns
    -------
    timm.models.vision_transformer.VisionTransformer
        the same object that was passed in, after modification
    """
    # swap in a patch embedding with the same geometry but a shared 1-channel projection
    previous_embed = vit_backbone.patch_embed
    vit_backbone.patch_embed = ChannelAgnosticPatchEmbed(
        img_size=previous_embed.img_size[0],
        patch_size=previous_embed.patch_size[0],
        embed_dim=vit_backbone.embed_dim,
    )

    # swap in sin-cos positional embeddings sized for up to max_in_chans channels;
    # the grid size is read from the freshly installed patch embedding
    tokens_per_side = vit_backbone.patch_embed.grid_size[0]
    has_class_token = vit_backbone.cls_token is not None
    vit_backbone.pos_embed = generate_2d_sincos_pos_embeddings(
        embedding_dim=vit_backbone.embed_dim,
        length=tokens_per_side,
        use_class_token=has_class_token,
        num_modality=max_in_chans,
    )

    # rebind the class so the ChannelAgnosticViT._pos_embed override is the one used
    vit_backbone.__class__ = ChannelAgnosticViT
    return vit_backbone
def sincos_positional_encoding_vit(
    vit_backbone: vit.VisionTransformer, scale: float = 10000.0
) -> vit.VisionTransformer:
    """Replace a pre-constructed ViT backbone's positional embeddings with no-grad sin-cos ones.

    Parameters
    ----------
    vit_backbone : timm.models.vision_transformer.VisionTransformer
        the constructed vision transformer from timm; modified in place
    scale : float (default 10000.0)
        hyperparameter for sincos positional embeddings, recommend keeping at 10,000

    Returns
    -------
    timm.models.vision_transformer.VisionTransformer
        the same ViT, now carrying fixed no-grad positional encodings
    """
    patch_embed = vit_backbone.patch_embed

    # tokens along one side of the image after patching (image assumed square)
    tokens_per_side = patch_embed.img_size[0] // patch_embed.patch_size[0]
    has_class_token = vit_backbone.cls_token is not None

    # note, if the model had weight_init == 'skip', this might get overwritten
    vit_backbone.pos_embed = generate_2d_sincos_pos_embeddings(
        vit_backbone.embed_dim,
        length=tokens_per_side,
        scale=scale,
        use_class_token=has_class_token,
    )
    return vit_backbone
def vit_small_patch16_256(**kwargs):
    """Call ``vit.vit_small_patch16_224`` with this module's defaults (img_size=256).

    Any keyword argument supplied by the caller overrides the matching default
    or is passed through unchanged.
    """
    model_kwargs = {
        "img_size": 256,
        "in_chans": 6,
        "num_classes": 0,
        "fc_norm": None,
        "class_token": True,
        "drop_path_rate": 0.1,
        "init_values": 0.0001,
        "block_fn": vit.ParallelScalingBlock,
        "qkv_bias": False,
        "qk_norm": True,
    }
    # caller-supplied values take precedence over the defaults
    model_kwargs.update(kwargs)
    return vit.vit_small_patch16_224(**model_kwargs)
def vit_small_patch32_512(**kwargs):
    """Call ``vit.vit_small_patch32_384`` with this module's defaults (img_size=512).

    Any keyword argument supplied by the caller overrides the matching default
    or is passed through unchanged.
    """
    model_kwargs = {
        "img_size": 512,
        "in_chans": 6,
        "num_classes": 0,
        "fc_norm": None,
        "class_token": True,
        "drop_path_rate": 0.1,
        "init_values": 0.0001,
        "block_fn": vit.ParallelScalingBlock,
        "qkv_bias": False,
        "qk_norm": True,
    }
    # caller-supplied values take precedence over the defaults
    model_kwargs.update(kwargs)
    return vit.vit_small_patch32_384(**model_kwargs)
def vit_base_patch8_256(**kwargs):
    """Call ``vit.vit_base_patch8_224`` with this module's defaults (img_size=256).

    Any keyword argument supplied by the caller overrides the matching default
    or is passed through unchanged.
    """
    model_kwargs = {
        "img_size": 256,
        "in_chans": 6,
        "num_classes": 0,
        "fc_norm": None,
        "class_token": True,
        "drop_path_rate": 0.1,
        "init_values": 0.0001,
        "block_fn": vit.ParallelScalingBlock,
        "qkv_bias": False,
        "qk_norm": True,
    }
    # caller-supplied values take precedence over the defaults
    model_kwargs.update(kwargs)
    return vit.vit_base_patch8_224(**model_kwargs)
def vit_base_patch16_256(**kwargs):
    """Call ``vit.vit_base_patch16_224`` with this module's defaults (img_size=256).

    Any keyword argument supplied by the caller overrides the matching default
    or is passed through unchanged.
    """
    model_kwargs = {
        "img_size": 256,
        "in_chans": 6,
        "num_classes": 0,
        "fc_norm": None,
        "class_token": True,
        "drop_path_rate": 0.1,
        "init_values": 0.0001,
        "block_fn": vit.ParallelScalingBlock,
        "qkv_bias": False,
        "qk_norm": True,
    }
    # caller-supplied values take precedence over the defaults
    model_kwargs.update(kwargs)
    return vit.vit_base_patch16_224(**model_kwargs)
def vit_base_patch32_512(**kwargs):
    """Call ``vit.vit_base_patch32_384`` with this module's defaults (img_size=512).

    Any keyword argument supplied by the caller overrides the matching default
    or is passed through unchanged.
    """
    model_kwargs = {
        "img_size": 512,
        "in_chans": 6,
        "num_classes": 0,
        "fc_norm": None,
        "class_token": True,
        "drop_path_rate": 0.1,
        "init_values": 0.0001,
        "block_fn": vit.ParallelScalingBlock,
        "qkv_bias": False,
        "qk_norm": True,
    }
    # caller-supplied values take precedence over the defaults
    model_kwargs.update(kwargs)
    return vit.vit_base_patch32_384(**model_kwargs)
def vit_large_patch8_256(**kwargs):
    """Construct ``vit.VisionTransformer`` directly with this module's large/8 defaults.

    Unlike the sibling factories, this one does not call a timm factory
    function, so the architecture (patch_size=8, embed_dim=1024, depth=24,
    num_heads=16) is spelled out here. Any keyword argument supplied by the
    caller overrides the matching default or is passed through unchanged.
    """
    model_kwargs = {
        "img_size": 256,
        "in_chans": 6,
        "num_classes": 0,
        "fc_norm": None,
        "class_token": True,
        "patch_size": 8,
        "embed_dim": 1024,
        "depth": 24,
        "num_heads": 16,
        "drop_path_rate": 0.3,
        "init_values": 0.0001,
        "block_fn": vit.ParallelScalingBlock,
        "qkv_bias": False,
        "qk_norm": True,
    }
    # caller-supplied values take precedence over the defaults
    model_kwargs.update(kwargs)
    return vit.VisionTransformer(**model_kwargs)
def vit_large_patch16_256(**kwargs):
    """Call ``vit.vit_large_patch16_384`` with this module's defaults (img_size=256).

    Any keyword argument supplied by the caller overrides the matching default
    or is passed through unchanged.
    """
    model_kwargs = {
        "img_size": 256,
        "in_chans": 6,
        "num_classes": 0,
        "fc_norm": None,
        "class_token": True,
        "drop_path_rate": 0.3,
        "init_values": 0.0001,
        "block_fn": vit.ParallelScalingBlock,
        "qkv_bias": False,
        "qk_norm": True,
    }
    # caller-supplied values take precedence over the defaults
    model_kwargs.update(kwargs)
    return vit.vit_large_patch16_384(**model_kwargs)