-
Notifications
You must be signed in to change notification settings - Fork 0
/
models_vim.py
984 lines (884 loc) · 41.4 KB
/
models_vim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import torch
import torch.nn as nn
from functools import partial
from torch import Tensor
from typing import Optional
from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, lecun_normal_
from timm.models.layers import DropPath, to_2tuple
from timm.models.vision_transformer import _load_weights
import math
from collections import namedtuple
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
from mamba_ssm.modules.mamba2 import Mamba2
from mamba_ssm.modules.mha import MHA
from rope import *
import random
try:
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
# Public names exported by a star-import of this module.
# NOTE(review): none of the names listed here are defined in the part of the
# file visible in this review - confirm they are defined elsewhere in the
# module, otherwise `from <module> import *` will fail on the missing names.
__all__ = [
    'vim_tiny_patch16_224', 'vim_small_patch16_224', 'vim_base_patch16_224',
    'vim_tiny_patch16_384', 'vim_small_patch16_384', 'vim_base_patch16_384',
]
def traid(t):
    """Normalize a size specification to a 3-element style tuple.

    Args:
        t: Either a single value, which is replicated for all three axes, or a
            tuple/list that already holds per-axis values.

    Returns:
        A tuple. Tuple/list inputs are returned as a tuple with their elements
        unchanged (the length is not validated here); any other value ``v``
        becomes ``(v, v, v)``.
    """
    if isinstance(t, (tuple, list)):
        # Lists (e.g. sizes read from a JSON/YAML config) are treated like
        # tuples instead of being replicated three times as a single value.
        return tuple(t)
    return (t, t, t)
class PatchEmbed3D(nn.Module):
    """Volume to patch-token embedding.

    A 3D convolution whose kernel and stride both equal the patch size maps a
    volume of shape (B, C, L, H, W) to one embedding vector per patch.
    """

    def __init__(self, volume_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        """
        Args:
            volume_size: Expected input extent; a scalar is expanded by ``traid``.
            patch_size: Patch extent; a scalar is expanded by ``traid``.
            in_chans: Number of input channels.
            embed_dim: Size of each patch embedding.
            norm_layer: Optional callable building a norm from ``embed_dim``;
                an identity is used when it is falsy.
            flatten: When true, ``forward`` returns (B, N, C) tokens instead of
                the raw convolution output.
        """
        super().__init__()
        self.volume_size = traid(volume_size)
        self.patch_size = traid(patch_size)
        vol, patch = self.volume_size, self.patch_size
        # Number of whole patches that fit along each of the three axes.
        self.grid_size = (vol[0] // patch[0], vol[1] // patch[1], vol[2] // patch[2])
        grid_l, grid_h, grid_w = self.grid_size
        self.num_patches = grid_l * grid_h * grid_w
        self.flatten = flatten

        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch, stride=patch)
        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        """Embed a volume; the input extent must match ``volume_size`` exactly."""
        B, C, L, H, W = x.shape
        assert L == self.volume_size[0] and H == self.volume_size[1] and W == self.volume_size[2], \
            f"Volume image size ({L}*{H}*{W}) doesn't match model ({self.volume_size[0]}*{self.volume_size[1]}*{self.volume_size[2]})."
        tokens = self.proj(x)
        if self.flatten:
            # (B, C, L', H', W') -> (B, N, C)
            tokens = tokens.flatten(2).transpose(1, 2)
        return self.norm(tokens)
class PatchEmbedUniversial(nn.Module):
    """Per-position (patch size 1) embedding for 2D images or 3D volumes.

    Every spatial position becomes one token: the spatial dimensions are
    flattened and a linear layer projects the channel vector of each position
    to ``embed_dim``.
    """

    def __init__(self, img_size=224, in_chans=3, embed_dim=768, norm_layer=None, if_3d=False):
        """
        Args:
            img_size: Expected input extent; a scalar is expanded to 2 (2D) or
                3 (3D) axes.
            in_chans: Number of input channels.
            embed_dim: Size of each token embedding.
            norm_layer: Optional callable building a norm from ``embed_dim``;
                an identity is used when it is falsy.
            if_3d: Treat inputs as (B, C, L, H, W) volumes instead of
                (B, C, H, W) images.
        """
        super().__init__()
        patch_size = 1
        if not if_3d:
            img_size = to_2tuple(img_size)
            patch_size = to_2tuple(patch_size)
        else:
            img_size = traid(img_size)
            patch_size = traid(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        # With patch size 1 the token count equals the number of positions.
        if not if_3d:
            self.num_patches = self.img_size[0] * self.img_size[1]
        else:
            self.num_patches = self.img_size[0] * self.img_size[1] * self.img_size[2]
        self.if_3d = if_3d
        self.proj = nn.Linear(in_chans, embed_dim)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """Return (B, N, embed_dim) tokens; the input extent must match ``img_size``."""
        if not self.if_3d:
            B, C, H, W = x.shape
            assert H == self.img_size[0] and W == self.img_size[1], \
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        else:
            B, C, L, H, W = x.shape
            assert L == self.img_size[0] and H == self.img_size[1] and W == self.img_size[2], \
                f"Volume image size ({L}*{H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}*{self.img_size[2]})."
        # Patch size is 1, so flattening the spatial dimensions yields the
        # tokens. flatten() is used instead of view() so that non-contiguous
        # inputs (e.g. permuted or channels-last tensors) do not raise.
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        x = self.norm(x)
        return x
class PatchEmbed(nn.Module):
    """Image to patch-token embedding.

    A 2D convolution with kernel ``patch_size`` and stride ``stride`` maps an
    image of shape (B, C, H, W) to one embedding vector per (possibly
    overlapping) patch.
    """

    def __init__(self, img_size=224, patch_size=16, stride=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        """
        Args:
            img_size: Expected input extent; expanded by ``to_2tuple``.
            patch_size: Convolution kernel extent; expanded by ``to_2tuple``.
            stride: Step between neighbouring patches.
            in_chans: Number of input channels.
            embed_dim: Size of each patch embedding.
            norm_layer: Optional callable building a norm from ``embed_dim``;
                an identity is used when it is falsy.
            flatten: When true, ``forward`` returns (B, N, C) tokens instead of
                the raw convolution output.
        """
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        img_h, img_w = self.img_size[0], self.img_size[1]
        patch_h, patch_w = self.patch_size[0], self.patch_size[1]
        # Output size of a convolution without padding along each axis.
        rows = (img_h - patch_h) // stride + 1
        cols = (img_w - patch_w) // stride + 1
        self.grid_size = (rows, cols)
        self.num_patches = rows * cols
        self.flatten = flatten

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=stride)
        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        """Embed an image; the input extent must match ``img_size`` exactly."""
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        tokens = self.proj(x)
        if self.flatten:
            # (B, C, H', W') -> (B, N, C)
            tokens = tokens.flatten(2).transpose(1, 2)
        return self.norm(tokens)
class Block(nn.Module):
    """Residual block: Add -> Norm -> Mixer, optionally followed by Add -> Norm -> MLP."""

    def __init__(
        self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False, drop_path=0.,
    ):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        Args:
            dim: Feature dimension passed to the norm, mixer and MLP factories.
            mixer_cls: Callable ``mixer_cls(dim)`` building the token mixer.
            mlp_cls: Callable ``mlp_cls(dim)`` building the MLP; passing
                ``nn.Identity`` itself disables the MLP sub-block.
            norm_cls: Callable ``norm_cls(dim)`` building the normalization.
            fused_add_norm: Use the fused add+norm kernel (``layer_norm_fn``).
            residual_in_fp32: Keep the residual stream in float32.
            drop_path: Stochastic-depth rate applied to the branch output
                before it is added to the residual; 0 disables it.
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.norm = norm_cls(dim)
        self.mixer = mixer_cls(dim)
        if mlp_cls is not nn.Identity:
            self.norm2 = norm_cls(dim)
            self.mlp = mlp_cls(dim)
        else:
            self.mlp = None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, **mixer_kwargs
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: running residual stream; ``None`` for the first block, in
                which case ``hidden_states`` starts the stream.
            inference_params: forwarded to the mixer.
            **mixer_kwargs: extra keyword arguments forwarded to the mixer.

        Returns:
            ``(hidden_states, residual)``: the output of the last sub-layer and
            the updated (un-normalized) residual stream.
        """
        if not self.fused_add_norm:
            residual = (self.drop_path(hidden_states) + residual) if residual is not None else hidden_states
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            hidden_states, residual = layer_norm_fn(
                self.drop_path(hidden_states),
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
                is_rms_norm=isinstance(self.norm, RMSNorm)
            )
        hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)

        if self.mlp is not None:
            if not self.fused_add_norm:
                residual = self.drop_path(hidden_states) + residual
                # The normalized tensor feeds the MLP; the residual stream must
                # stay un-normalized, exactly as in the fused branch below.
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                hidden_states, residual = layer_norm_fn(
                    self.drop_path(hidden_states),
                    self.norm2.weight,
                    self.norm2.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm2.eps,
                    is_rms_norm=isinstance(self.norm2, RMSNorm)
                )
            hidden_states = self.mlp(hidden_states)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the mixer."""
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
def create_block(
    d_model,
    ssm_cfg=None,
    attn_layer_idx=None,
    attn_cfg=None,
    norm_epsilon=1e-5,
    drop_path=0.,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
    if_bimamba=False,
    bimamba_type="none",
    if_devide_out=False,
    init_layer_scale=None,
    mixer_type = 'mamba',
):
    """Build one ``Block`` whose mixer is a Mamba, Mamba2 or attention layer.

    Args:
        d_model: Feature dimension of the block.
        ssm_cfg: Extra keyword arguments for the Mamba/Mamba2 mixer (``None`` -> ``{}``).
        attn_layer_idx: Collection of layer indices that use attention (``MHA``)
            instead of the state-space mixer; ``None`` means no attention layers.
        attn_cfg: Extra keyword arguments for ``MHA`` (``None`` -> ``{}``).
        norm_epsilon: Epsilon passed to the normalization layer.
        drop_path: Stochastic-depth rate of the block.
        rms_norm: Use ``RMSNorm`` instead of ``nn.LayerNorm``.
        residual_in_fp32: Forwarded to ``Block``.
        fused_add_norm: Forwarded to ``Block``.
        layer_idx: Index of this layer; forwarded to the mixer and stored on
            the returned block as ``block.layer_idx``.
        device, dtype: Factory keyword arguments for mixer and norm.
        if_bimamba: When true, forces ``bimamba_type`` to ``"v1"``.
        bimamba_type, if_devide_out, init_layer_scale: Forwarded to ``Mamba``.
        mixer_type: ``'mamba'`` or ``'mamba2'``; selects the state-space mixer
            for non-attention layers.

    Returns:
        The constructed ``Block`` (built without an MLP sub-block).

    Raises:
        ValueError: If a non-attention layer is requested with an unknown
            ``mixer_type``.
        ImportError: If ``rms_norm`` is requested but ``RMSNorm`` could not be
            imported.
    """
    if if_bimamba:
        bimamba_type = "v1"
    if ssm_cfg is None:
        ssm_cfg = {}
    if attn_cfg is None:
        # Without this, `**attn_cfg` below fails for attention layers when the
        # caller relies on the default.
        attn_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}

    if (attn_layer_idx is None) or (layer_idx not in attn_layer_idx):
        if mixer_type == 'mamba':
            mixer_cls = partial(Mamba, layer_idx=layer_idx, bimamba_type=bimamba_type, if_devide_out=if_devide_out, init_layer_scale=init_layer_scale, **ssm_cfg, **factory_kwargs)
        elif mixer_type == 'mamba2':
            mixer_cls = partial(Mamba2, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
        else:
            raise ValueError(
                f"Unknown mixer_type {mixer_type!r}; expected 'mamba' or 'mamba2'."
            )
    else:
        mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)

    # No MLP sub-block is used: Block treats nn.Identity as "disabled".
    mlp_cls = nn.Identity

    if rms_norm and RMSNorm is None:
        raise ImportError("rms_norm=True requires mamba_ssm's RMSNorm, which failed to import.")
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    block = Block(
        d_model,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        drop_path=drop_path,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
module,
n_layer,
initializer_range=0.02, # Now only used for embedding layer.
rescale_prenorm_residual=True,
n_residuals_per_layer=1, # Change to 2 if we have MLP
):
if isinstance(module, nn.Linear):
if module.bias is not None:
if not getattr(module.bias, "_no_reinit", False):
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=initializer_range)
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight", "fc2.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
with torch.no_grad():
p /= math.sqrt(n_residuals_per_layer * n_layer)
def segm_init_weights(m):
    """Initialize one module; intended to be used through ``Module.apply``.

    * ``nn.Linear``: truncated-normal weights (std 0.02), zero bias.
    * ``nn.Conv2d``: LeCun-normal weights, zero bias.
    * ``nn.LayerNorm`` / ``nn.GroupNorm`` / ``nn.BatchNorm2d``: unit weight,
      zero bias (skipped for parameters that do not exist).

    Any other module type is left untouched.

    Args:
        m: The module being visited.
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Conv2d):
        # NOTE conv was left to pytorch default in my original init
        # NOTE(review): nn.Conv3d (used by the 3D patch embedding) does not
        # match this branch and keeps the PyTorch default init - confirm intended.
        lecun_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        # Norm layers built without affine parameters expose weight/bias as
        # None; initializing those would raise.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
        if m.weight is not None:
            nn.init.ones_(m.weight)
class VisionMamba(nn.Module):
    """Vision model that runs a stack of ``Block`` layers over patch tokens.

    The input is split into patch tokens by ``PatchEmbed`` (2D) or
    ``PatchEmbed3D`` (``if_3d=True``). Class token(s), absolute position
    embeddings and rotary embeddings are added when enabled. The tokens then
    pass through ``depth`` blocks built by ``create_block``, a final norm, a
    pooling step and a linear head.

    Args:
        img_size: Input extent passed to the patch embedding (volume extent
            when ``if_3d``).
        patch_size: Patch extent passed to the patch embedding.
        stride: Patch stride; only forwarded to the 2D ``PatchEmbed``.
        depth: Number of blocks.
        embed_dim: Token feature dimension.
        channels: Number of input channels.
        num_classes: Output size of the head; the head is an identity when
            this is not positive.
        attn_layer_idx, attn_cfg, ssm_cfg: Forwarded to ``create_block``.
        drop_rate: Dropout applied after adding the absolute position
            embedding (only used with ``if_abs_pos_embed``).
        drop_path_rate: Upper end of the linearly spaced stochastic-depth
            schedule.
        norm_epsilon: Epsilon of the normalization layers.
        rms_norm: Use ``RMSNorm`` instead of ``nn.LayerNorm``.
        initializer_cfg: Optional keyword arguments for ``_init_weights``.
        fused_add_norm: Use the fused add+norm kernel.
        residual_in_fp32: Keep the residual stream in float32.
        device, dtype: Factory keyword arguments for blocks and final norm.
        ft_seq_len: Not used by the code in this class.
        pt_hw_seq_len: Forwarded to the rotary embedding as ``pt_seq_len``.
        if_bidirectional: Process blocks in (forward, backward) pairs.
        final_pool_type: ``'none'`` (last token), ``'mean'``, ``'max'`` or
            ``'all'``; only used when no class token is configured.
        if_abs_pos_embed: Add a learned absolute position embedding.
        if_rope: Apply rotary embeddings before every block.
        if_rope_residual: Also apply rotary embeddings to the residual.
        flip_img_sequences_ratio: Probability-like threshold for randomly
            reversing the token sequence; non-positive disables it.
        if_bimamba, bimamba_type, mixer_type, if_devide_out, init_layer_scale:
            Forwarded to ``create_block``.
        if_cls_token: Insert class token(s) and return their features.
        use_double_cls_token: Use one class token at each end of the sequence.
        use_middle_cls_token: Insert the single class token in the middle.
        if_3d: Treat inputs as volumes.
        **kwargs: Accepted for compatibility; not used beyond receiving
            ``device``/``dtype``.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 stride=16,
                 depth=24,
                 embed_dim=192,
                 channels=3,
                 num_classes=1000,
                 attn_layer_idx=None,
                 attn_cfg=None,
                 ssm_cfg=None,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 norm_epsilon: float = 1e-5,
                 rms_norm: bool = False,
                 initializer_cfg=None,
                 fused_add_norm=False,
                 residual_in_fp32=False,
                 device=None,
                 dtype=None,
                 ft_seq_len=None,
                 pt_hw_seq_len=14,
                 if_bidirectional=False,
                 final_pool_type='none',
                 if_abs_pos_embed=False,
                 if_rope=False,
                 if_rope_residual=False,
                 flip_img_sequences_ratio=-1.,
                 if_bimamba=False,
                 bimamba_type="none",
                 mixer_type='mamba',
                 if_cls_token=False,
                 if_devide_out=False,
                 init_layer_scale=None,
                 use_double_cls_token=False,
                 use_middle_cls_token=False,
                 if_3d = False,
                 **kwargs):
        factory_kwargs = {"device": device, "dtype": dtype}
        # add factory_kwargs into kwargs
        kwargs.update(factory_kwargs)
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.if_bidirectional = if_bidirectional
        self.final_pool_type = final_pool_type
        self.if_abs_pos_embed = if_abs_pos_embed
        self.if_rope = if_rope
        self.if_rope_residual = if_rope_residual
        self.flip_img_sequences_ratio = flip_img_sequences_ratio
        self.if_cls_token = if_cls_token
        self.use_double_cls_token = use_double_cls_token
        self.use_middle_cls_token = use_middle_cls_token
        self.num_tokens = 1 if if_cls_token else 0
        self.mixer_type = mixer_type

        # pretrain parameters
        self.num_classes = num_classes
        self.d_model = self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        # The patch embedding only carries its own norm when no class token is used.
        patch_norm = nn.LayerNorm if not self.if_cls_token else None
        if not if_3d:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, stride=stride, in_chans=channels, embed_dim=embed_dim,
                norm_layer=patch_norm)
        else:
            self.patch_embed = PatchEmbed3D(
                volume_size=img_size, patch_size=patch_size, in_chans=channels, embed_dim=embed_dim,
                norm_layer=patch_norm)
        num_patches = self.patch_embed.num_patches

        if if_cls_token:
            if use_double_cls_token:
                self.cls_token_head = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
                self.cls_token_tail = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
                self.num_tokens = 2
            else:
                self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))

        if if_abs_pos_embed:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, self.embed_dim))
            self.pos_drop = nn.Dropout(p=drop_rate)

        if if_rope:
            half_head_dim = embed_dim // 2
            hw_seq_len = img_size // patch_size
            self.rope = VisionRotaryEmbeddingFast(
                dim=half_head_dim,
                pt_seq_len=pt_hw_seq_len,
                ft_seq_len=hw_seq_len,
                if_3d=if_3d
            )

        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        # TODO: release this comment
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        # Shifted by one entry so that block i uses the rate of step i-1 (first block: 0).
        inter_dpr = [0.0] + dpr
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        # transformer blocks
        self.blocks = nn.ModuleList(
            [
                create_block(
                    embed_dim,
                    ssm_cfg=ssm_cfg,
                    attn_layer_idx=attn_layer_idx,
                    attn_cfg=attn_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    if_bimamba=if_bimamba,
                    bimamba_type=bimamba_type,
                    drop_path=inter_dpr[i],
                    if_devide_out=if_devide_out,
                    init_layer_scale=init_layer_scale,
                    mixer_type=mixer_type,
                    **factory_kwargs,
                )
                for i in range(depth)
            ]
        )

        # output head
        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            embed_dim, eps=norm_epsilon, **factory_kwargs
        )

        # original init
        self.patch_embed.apply(segm_init_weights)
        self.head.apply(segm_init_weights)
        if if_abs_pos_embed:
            trunc_normal_(self.pos_embed, std=.02)
        if if_cls_token:
            if use_double_cls_token:
                trunc_normal_(self.cls_token_head, std=.02)
                trunc_normal_(self.cls_token_tail, std=.02)
            else:
                trunc_normal_(self.cls_token, std=.02)

        # mamba init
        self.apply(
            partial(
                _init_weights,
                n_layer=depth,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Return a dict mapping block index to that block's inference cache."""
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.blocks)
        }

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names that should be excluded from weight decay."""
        return {"pos_embed", "cls_token", "dist_token", "cls_token_head", "cls_token_tail"}

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=""):
        """Load weights from ``checkpoint_path`` via timm's ``_load_weights``."""
        _load_weights(self, checkpoint_path, prefix)

    def unpatchify_3d(self, x):
        """Rearrange per-patch vectors back into a volume.

        Args:
            x: Tensor of shape (N, L, p**3 * C), where L must be a perfect cube
                and p is the first entry of the patch size.

        Returns:
            Tensor of shape (N, C, D, H, W) with D = H = W = cbrt(L) * p.
        """
        p = self.patch_embed.patch_size[0]
        l = h = w = round(x.shape[1] ** (1 / 3))
        assert l * h * w == x.shape[1]
        # The channel count is inferred (-1) rather than hard-coded to 3.
        x = x.reshape(shape=(x.shape[0], l, h, w, p, p, p, -1))
        x = torch.einsum('nlhwrpqc->nclrhpwq', x)
        volume = x.reshape(shape=(x.shape[0], -1, l * p, h * p, w * p))
        return volume

    def forward_features(self, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
        """Compute pooled features (or all token features) for a batch.

        Args:
            x: Input batch accepted by the patch embedding.
            inference_params: Forwarded to every block.
            if_random_cls_token_position: Insert the single class token at a
                random position (ignored with middle/double class tokens).
            if_random_token_rank: Randomly permute the token order.

        Returns:
            The class-token feature(s) when class tokens are enabled (averaged
            for the double-token variant); otherwise the result selected by
            ``final_pool_type``.

        Raises:
            NotImplementedError: For an unknown ``final_pool_type``.
        """
        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
        # with slight modifications to add the dist_token
        x = self.patch_embed(x)
        B, M, _ = x.shape

        # Position (int) or positions (list) of the class token(s) in the
        # sequence; stays None when no class token is used.
        token_position = None
        if self.if_cls_token:
            if self.use_double_cls_token:
                cls_token_head = self.cls_token_head.expand(B, -1, -1)
                cls_token_tail = self.cls_token_tail.expand(B, -1, -1)
                token_position = [0, M + 1]
                x = torch.cat((cls_token_head, x, cls_token_tail), dim=1)
            else:
                cls_token = self.cls_token.expand(B, -1, -1)
                if self.use_middle_cls_token:
                    token_position = M // 2
                elif if_random_cls_token_position:
                    token_position = random.randint(0, M)
                else:
                    token_position = 0
                # Splice the class token in at `token_position` (0 -> prepend).
                x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
            M = x.shape[1]

        if self.if_abs_pos_embed:
            x = x + self.pos_embed
            x = self.pos_drop(x)

        if if_random_token_rank:
            # Draw a random permutation of the token order.
            shuffle_indices = torch.randperm(M)

            if isinstance(token_position, list):
                print("original value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
            elif token_position is not None:
                print("original value: ", x[0, token_position, 0])
            if token_position is not None:
                print("original token_position: ", token_position)

            # Apply the permutation.
            x = x[:, shuffle_indices, :]

            # Locate where the class token(s) ended up after the shuffle.
            if isinstance(token_position, list):
                token_position = [
                    torch.where(shuffle_indices == position)[0].item() for position in token_position
                ]
                print("new value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
            elif token_position is not None:
                token_position = torch.where(shuffle_indices == token_position)[0].item()
                print("new value: ", x[0, token_position, 0])
            if token_position is not None:
                print("new token_position: ", token_position)

        if_flip_img_sequences = False
        if self.flip_img_sequences_ratio > 0 and (self.flip_img_sequences_ratio - random.random()) > 1e-5:
            x = x.flip([1])
            if_flip_img_sequences = True
            # The sequence is never flipped back, so the class-token index has
            # to be mirrored as well; otherwise a patch token would be returned.
            if isinstance(token_position, list):
                token_position = [M - 1 - position for position in token_position]
            elif token_position is not None:
                token_position = M - 1 - token_position

        # mamba impl
        residual = None
        hidden_states = x
        if not self.if_bidirectional:
            for layer in self.blocks:
                # Rotary embeddings are applied in the original token order:
                # undo the flip, apply rope, then flip again.
                if if_flip_img_sequences and self.if_rope:
                    hidden_states = hidden_states.flip([1])
                    if residual is not None:
                        residual = residual.flip([1])

                if self.if_rope:
                    hidden_states = self.rope(hidden_states)
                    if residual is not None and self.if_rope_residual:
                        residual = self.rope(residual)

                if if_flip_img_sequences and self.if_rope:
                    hidden_states = hidden_states.flip([1])
                    if residual is not None:
                        residual = residual.flip([1])

                hidden_states, residual = layer(
                    hidden_states, residual, inference_params=inference_params
                )
        else:
            # get two layers in a single for-loop
            for i in range(len(self.blocks) // 2):
                if self.if_rope:
                    hidden_states = self.rope(hidden_states)
                    if residual is not None and self.if_rope_residual:
                        residual = self.rope(residual)

                hidden_states_f, residual_f = self.blocks[i * 2](
                    hidden_states, residual, inference_params=inference_params
                )
                hidden_states_b, residual_b = self.blocks[i * 2 + 1](
                    hidden_states.flip([1]), None if residual is None else residual.flip([1]), inference_params=inference_params
                )
                hidden_states = hidden_states_f + hidden_states_b.flip([1])
                residual = residual_f + residual_b.flip([1])

        if not self.fused_add_norm:
            if residual is None:
                residual = hidden_states
            else:
                residual = residual + self.drop_path(hidden_states)
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            hidden_states = layer_norm_fn(
                self.drop_path(hidden_states),
                self.norm_f.weight,
                self.norm_f.bias,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm_f.eps,
                is_rms_norm=isinstance(self.norm_f, RMSNorm)
            )

        # return only cls token if it exists
        if self.if_cls_token:
            if self.use_double_cls_token:
                return (hidden_states[:, token_position[0], :] + hidden_states[:, token_position[1], :]) / 2
            return hidden_states[:, token_position, :]

        if self.final_pool_type == 'none':
            return hidden_states[:, -1, :]
        elif self.final_pool_type == 'mean':
            return hidden_states.mean(dim=1)
        elif self.final_pool_type == 'max':
            # Max pooling happens in forward(), after the head has been applied.
            return hidden_states
        elif self.final_pool_type == 'all':
            return hidden_states
        else:
            raise NotImplementedError

    def forward(self, x, return_features=False, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
        """Run the model.

        Returns the output of ``forward_features`` when ``return_features`` is
        true; otherwise the head output, max-reduced over the token dimension
        when ``final_pool_type == 'max'``.
        """
        x = self.forward_features(x, inference_params, if_random_cls_token_position=if_random_cls_token_position, if_random_token_rank=if_random_token_rank)
        if return_features:
            return x
        x = self.head(x)
        if self.final_pool_type == 'max':
            x = x.max(dim=1)[0]
        return x
@register_model
def vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs):
    """Build the tiny preset (embed_dim 192, depth 24, patch 16) with a middle class token.

    Args:
        pretrained: When true, download a checkpoint and load its ``"model"``
            entry into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.

    Returns:
        The constructed model with ``default_cfg`` set from ``_cfg()``.
    """
    preset = dict(
        patch_size=16,
        embed_dim=192,
        depth=24,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=True,
        if_rope_residual=True,
        bimamba_type="v2",
        if_cls_token=True,
        if_devide_out=True,
        use_middle_cls_token=True,
    )
    model = VisionMamba(**preset, **kwargs)
    model.default_cfg = _cfg()
    if not pretrained:
        return model
    # The checkpoint URL is still a placeholder value in this file.
    state = torch.hub.load_state_dict_from_url(url="to.do", map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model
@register_model
def vim_tiny_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs):
    """Build the tiny preset (embed_dim 192, depth 24, patch 16, stride 8) with a middle class token.

    Args:
        pretrained: When true, download a checkpoint and load its ``"model"``
            entry into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.

    Returns:
        The constructed model with ``default_cfg`` set from ``_cfg()``.
    """
    preset = dict(
        patch_size=16,
        stride=8,
        embed_dim=192,
        depth=24,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=True,
        if_rope_residual=True,
        bimamba_type="v2",
        if_cls_token=True,
        if_devide_out=True,
        use_middle_cls_token=True,
    )
    model = VisionMamba(**preset, **kwargs)
    model.default_cfg = _cfg()
    if not pretrained:
        return model
    # The checkpoint URL is still a placeholder value in this file.
    state = torch.hub.load_state_dict_from_url(url="to.do", map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model
@register_model
def vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs):
    """Build the small preset (embed_dim 384, depth 24, patch 16) with a middle class token.

    Args:
        pretrained: When true, download a checkpoint and load its ``"model"``
            entry into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.

    Returns:
        The constructed model with ``default_cfg`` set from ``_cfg()``.
    """
    preset = dict(
        patch_size=16,
        embed_dim=384,
        depth=24,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=True,
        if_rope_residual=True,
        bimamba_type="v2",
        if_cls_token=True,
        if_devide_out=True,
        use_middle_cls_token=True,
    )
    model = VisionMamba(**preset, **kwargs)
    model.default_cfg = _cfg()
    if not pretrained:
        return model
    # The checkpoint URL is still a placeholder value in this file.
    state = torch.hub.load_state_dict_from_url(url="to.do", map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model
@register_model
def vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs):
    """Build the small preset (embed_dim 384, depth 24, patch 16, stride 8) with a middle class token.

    Args:
        pretrained: When true, download a checkpoint and load its ``"model"``
            entry into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.

    Returns:
        The constructed model with ``default_cfg`` set from ``_cfg()``.
    """
    preset = dict(
        patch_size=16,
        stride=8,
        embed_dim=384,
        depth=24,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=True,
        if_rope_residual=True,
        bimamba_type="v2",
        if_cls_token=True,
        if_devide_out=True,
        use_middle_cls_token=True,
    )
    model = VisionMamba(**preset, **kwargs)
    model.default_cfg = _cfg()
    if not pretrained:
        return model
    # The checkpoint URL is still a placeholder value in this file.
    state = torch.hub.load_state_dict_from_url(url="to.do", map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model
@register_model
def vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_div2(pretrained=False, **kwargs):
    """Build the small preset (embed_dim 384, depth 12, patch 16) with mean pooling and no class token.

    Args:
        pretrained: When true, download a checkpoint and load its ``"model"``
            entry into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.

    Returns:
        The constructed model with ``default_cfg`` set from ``_cfg()``.
    """
    preset = dict(
        patch_size=16,
        embed_dim=384,
        depth=12,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=True,
        if_rope_residual=True,
        bimamba_type="v2",
        if_cls_token=False,
        if_devide_out=True,
        use_middle_cls_token=False,
    )
    model = VisionMamba(**preset, **kwargs)
    model.default_cfg = _cfg()
    if not pretrained:
        return model
    # The checkpoint URL is still a placeholder value in this file.
    state = torch.hub.load_state_dict_from_url(url="to.do", map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model
@register_model
def vim_small_patch8_224_bimambav2_final_pool_mean_abs_pos_embed_div2(pretrained=False, **kwargs):
    """Build the small preset (embed_dim 384, depth 12, patch 8, stride 8) with mean pooling and no class token.

    Args:
        pretrained: When true, download a checkpoint and load its ``"model"``
            entry into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.

    Returns:
        The constructed model with ``default_cfg`` set from ``_cfg()``.
    """
    preset = dict(
        patch_size=8,
        stride=8,
        embed_dim=384,
        depth=12,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=True,
        if_rope_residual=True,
        bimamba_type="v2",
        if_cls_token=False,
        if_devide_out=True,
        use_middle_cls_token=False,
    )
    model = VisionMamba(**preset, **kwargs)
    model.default_cfg = _cfg()
    if not pretrained:
        return model
    # The checkpoint URL is still a placeholder value in this file.
    state = torch.hub.load_state_dict_from_url(url="to.do", map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model
@register_model
def vim_small_patch4_224_bimambav2_final_pool_mean_abs_pos_embed_div2(pretrained=False, **kwargs):
    """Build the small preset (embed_dim 384, depth 12, patch 4, stride 4) with mean pooling and no class token.

    Args:
        pretrained: When true, download a checkpoint and load its ``"model"``
            entry into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.

    Returns:
        The constructed model with ``default_cfg`` set from ``_cfg()``.
    """
    preset = dict(
        patch_size=4,
        stride=4,
        embed_dim=384,
        depth=12,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=True,
        if_rope_residual=True,
        bimamba_type="v2",
        if_cls_token=False,
        if_devide_out=True,
        use_middle_cls_token=False,
    )
    model = VisionMamba(**preset, **kwargs)
    model.default_cfg = _cfg()
    if not pretrained:
        return model
    # The checkpoint URL is still a placeholder value in this file.
    state = torch.hub.load_state_dict_from_url(url="to.do", map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model
@register_model
def vim_base_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_div2(pretrained=False, **kwargs):
    """Build the base preset (embed_dim 768, depth 24, patch 16, stride 8) with mean pooling and no class token.

    Args:
        pretrained: When true, download a checkpoint and load its ``"model"``
            entry into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.

    Returns:
        The constructed model with ``default_cfg`` set from ``_cfg()``.
    """
    preset = dict(
        patch_size=16,
        stride=8,
        embed_dim=768,
        depth=24,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=True,
        if_rope_residual=True,
        bimamba_type="v2",
        if_cls_token=False,
        if_devide_out=True,
        use_middle_cls_token=False,
    )
    model = VisionMamba(**preset, **kwargs)
    model.default_cfg = _cfg()
    if not pretrained:
        return model
    # The checkpoint URL is still a placeholder value in this file.
    state = torch.hub.load_state_dict_from_url(url="to.do", map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model
@register_model
def vim_3D_base_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_div2(pretrained=False, **kwargs):
    """Build the 3D base preset (embed_dim 768, depth 12, patch 16) with mean pooling and no class token.

    NOTE(review): ``stride=8`` is passed, but ``VisionMamba`` appears to build
    the 3D patch embedding without a stride argument - confirm whether the
    stride has any effect for 3D presets.

    Args:
        pretrained: When true, download a checkpoint and load its ``"model"``
            entry into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.

    Returns:
        The constructed model with ``default_cfg`` set from ``_cfg()``.
    """
    preset = dict(
        patch_size=16,
        stride=8,
        embed_dim=768,
        depth=12,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        if_abs_pos_embed=True,
        final_pool_type='mean',
        if_rope=True,
        if_rope_residual=True,
        bimamba_type="v2",
        if_cls_token=False,
        if_devide_out=True,
        use_middle_cls_token=False,
        if_3d=True,
    )
    model = VisionMamba(**preset, **kwargs)
    model.default_cfg = _cfg()
    if not pretrained:
        return model
    # The checkpoint URL is still a placeholder value in this file.
    state = torch.hub.load_state_dict_from_url(url="to.do", map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model
@register_model
def vim_3D_small_patch16_stride16_224_bimambav2_final_pool_mean_abs_pos_embed_div2(pretrained=False, **kwargs):
    """Build the 3D small preset (embed_dim 384, depth 12, patch 16) with mean pooling and no class token.

    Args:
        pretrained: When true, download a checkpoint and load its ``"model"``
            entry into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.

    Returns:
        The constructed model with ``default_cfg`` set from ``_cfg()``.
    """
    preset = dict(
        patch_size=16,
        stride=16,
        embed_dim=384,
        depth=12,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=True,
        if_rope_residual=True,
        bimamba_type="v2",
        if_cls_token=False,
        if_devide_out=True,
        use_middle_cls_token=False,
        if_3d=True,
    )
    model = VisionMamba(**preset, **kwargs)
    model.default_cfg = _cfg()
    if not pretrained:
        return model
    # The checkpoint URL is still a placeholder value in this file.
    state = torch.hub.load_state_dict_from_url(url="to.do", map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model
@register_model
def vim_3D_small_patch8_stride8_224_bimambav2_final_pool_none_abs_pos_embed_div2(pretrained=False, **kwargs):
    """Build the 3D small preset (embed_dim 384, depth 12, patch 8) with ``final_pool_type='none'`` and no class token.

    Args:
        pretrained: When true, download a checkpoint and load its ``"model"``
            entry into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.

    Returns:
        The constructed model with ``default_cfg`` set from ``_cfg()``.
    """
    preset = dict(
        patch_size=8,
        stride=8,
        embed_dim=384,
        depth=12,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='none',
        if_abs_pos_embed=True,
        if_rope=True,
        if_rope_residual=True,
        bimamba_type="v2",
        if_cls_token=False,
        if_devide_out=True,
        use_middle_cls_token=False,
        if_3d=True,
    )
    model = VisionMamba(**preset, **kwargs)
    model.default_cfg = _cfg()
    if not pretrained:
        return model
    # The checkpoint URL is still a placeholder value in this file.
    state = torch.hub.load_state_dict_from_url(url="to.do", map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model
@register_model
def vim_3D_small_patch4_stride4_224_bimambav2_final_pool_mean_abs_pos_embed_div2(pretrained=False, **kwargs):
    """Build the 3D VisionMamba variant with patch 4, stride 4, width 384, depth 12.

    Args:
        pretrained: When true, download a checkpoint through ``torch.hub`` and
            load its ``"model"`` entry into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.

    Returns:
        The constructed model, with ``default_cfg`` set from ``_cfg()``.
    """
    arch = dict(
        patch_size=4,
        stride=4,
        embed_dim=384,
        depth=12,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=True,
        if_rope_residual=True,
        bimamba_type="v2",
        if_cls_token=False,
        if_devide_out=True,
        use_middle_cls_token=False,
        if_3d=True,
    )
    # Unpacking both mappings keeps the TypeError on duplicated keywords.
    net = VisionMamba(**arch, **kwargs)
    net.default_cfg = _cfg()
    if not pretrained:
        return net
    # NOTE(review): "to.do" is a placeholder, not a real checkpoint URL.
    state = torch.hub.load_state_dict_from_url(
        url="to.do", map_location="cpu", check_hash=True
    )
    net.load_state_dict(state["model"])
    return net
@register_model
def vim_3D_small_patch8_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_div2(pretrained=False, **kwargs):
    """Build the 3D VisionMamba variant with patch 8, stride 8, width 384, depth 12.

    This variant passes ``final_pool_type='mean'`` to ``VisionMamba``.

    Args:
        pretrained: When true, download a checkpoint through ``torch.hub`` and
            load its ``"model"`` entry into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.

    Returns:
        The constructed model, with ``default_cfg`` set from ``_cfg()``.
    """
    model = VisionMamba(
        patch_size=8,
        stride=8,
        embed_dim=384,
        depth=12,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=True,
        if_rope_residual=True,
        bimamba_type="v2",
        if_cls_token=False,
        if_devide_out=True,
        use_middle_cls_token=False,
        if_3d=True,
        **kwargs,
    )
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE(review): "to.do" is a placeholder, not a real checkpoint URL.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="to.do",
            map_location="cpu",
            check_hash=True,
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def vim_3D_small_patch32_stride32_224_bimambav2_final_pool_mean_abs_pos_embed_div2(pretrained=False, **kwargs):
    """Build the 3D VisionMamba variant with patch 32, stride 32, width 384, depth 12.

    Args:
        pretrained: When true, download a checkpoint through ``torch.hub`` and
            load its ``"model"`` entry into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.

    Returns:
        The constructed model, with ``default_cfg`` set from ``_cfg()``.
    """
    arch = dict(
        patch_size=32,
        stride=32,
        embed_dim=384,
        depth=12,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=True,
        if_rope_residual=True,
        bimamba_type="v2",
        if_cls_token=False,
        if_devide_out=True,
        use_middle_cls_token=False,
        if_3d=True,
    )
    # Unpacking both mappings keeps the TypeError on duplicated keywords.
    net = VisionMamba(**arch, **kwargs)
    net.default_cfg = _cfg()
    if not pretrained:
        return net
    # NOTE(review): "to.do" is a placeholder, not a real checkpoint URL.
    state = torch.hub.load_state_dict_from_url(
        url="to.do", map_location="cpu", check_hash=True
    )
    net.load_state_dict(state["model"])
    return net
@register_model
def vim2_3D_small_patch32_stride32_final_pool_mean_abs_pos_embed_div2(pretrained=False, **kwargs):
    """Build the 3D VisionMamba variant with the 'mamba2' mixer, patch 32, stride 32.

    Uses width 512 and depth 12, with both rope flags disabled.

    Args:
        pretrained: When true, download a checkpoint through ``torch.hub`` and
            load its ``"model"`` entry into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.

    Returns:
        The constructed model, with ``default_cfg`` set from ``_cfg()``.
    """
    arch = dict(
        patch_size=32,
        stride=32,
        embed_dim=512,
        depth=12,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=False,
        if_rope_residual=False,
        bimamba_type="v2",
        if_cls_token=False,
        if_devide_out=True,
        use_middle_cls_token=False,
        if_3d=True,
        mixer_type='mamba2',
    )
    # Unpacking both mappings keeps the TypeError on duplicated keywords.
    net = VisionMamba(**arch, **kwargs)
    net.default_cfg = _cfg()
    if not pretrained:
        return net
    # NOTE(review): "to.do" is a placeholder, not a real checkpoint URL.
    state = torch.hub.load_state_dict_from_url(
        url="to.do", map_location="cpu", check_hash=True
    )
    net.load_state_dict(state["model"])
    return net
@register_model
def vim2_3D_small_patch16_stride16_final_pool_mean_abs_pos_embed_div2(pretrained=False, **kwargs):
    """Build the 3D VisionMamba variant with the 'mamba2' mixer, patch 16, stride 16.

    Uses width 512 and depth 12, with both rope flags disabled.

    Args:
        pretrained: When true, download a checkpoint through ``torch.hub`` and
            load its ``"model"`` entry into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.

    Returns:
        The constructed model, with ``default_cfg`` set from ``_cfg()``.
    """
    arch = dict(
        patch_size=16,
        stride=16,
        embed_dim=512,
        depth=12,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=False,
        if_rope_residual=False,
        bimamba_type="v2",
        if_cls_token=False,
        if_devide_out=True,
        use_middle_cls_token=False,
        if_3d=True,
        mixer_type='mamba2',
    )
    # Unpacking both mappings keeps the TypeError on duplicated keywords.
    net = VisionMamba(**arch, **kwargs)
    net.default_cfg = _cfg()
    if not pretrained:
        return net
    # NOTE(review): "to.do" is a placeholder, not a real checkpoint URL.
    state = torch.hub.load_state_dict_from_url(
        url="to.do", map_location="cpu", check_hash=True
    )
    net.load_state_dict(state["model"])
    return net
@register_model
def vim2_3D_small_patch4_stride4_final_pool_mean_abs_pos_embed_div2(pretrained=False, **kwargs):
    """Build the 3D VisionMamba variant with the 'mamba2' mixer, patch 4, stride 4.

    Uses width 512 and depth 12, with both rope flags disabled.

    Args:
        pretrained: When true, download a checkpoint through ``torch.hub`` and
            load its ``"model"`` entry into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.

    Returns:
        The constructed model, with ``default_cfg`` set from ``_cfg()``.
    """
    arch = dict(
        patch_size=4,
        stride=4,
        embed_dim=512,
        depth=12,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=False,
        if_rope_residual=False,
        bimamba_type="v2",
        if_cls_token=False,
        if_devide_out=True,
        use_middle_cls_token=False,
        if_3d=True,
        mixer_type='mamba2',
    )
    # Unpacking both mappings keeps the TypeError on duplicated keywords.
    net = VisionMamba(**arch, **kwargs)
    net.default_cfg = _cfg()
    if not pretrained:
        return net
    # NOTE(review): "to.do" is a placeholder, not a real checkpoint URL.
    state = torch.hub.load_state_dict_from_url(
        url="to.do", map_location="cpu", check_hash=True
    )
    net.load_state_dict(state["model"])
    return net
@register_model
def vim_mha_3D_small_patch16_stride16_final_pool_mean_abs_pos_embed_div2(pretrained=False, **kwargs):
    """Build the 3D 'mamba2' VisionMamba variant that also passes attention settings.

    Uses patch 16, stride 16, width 512 and depth 12. An ``attn_cfg`` dict and
    ``attn_layer_idx=[3, 6, 9]`` are passed to ``VisionMamba`` as well.

    Args:
        pretrained: When true, download a checkpoint through ``torch.hub`` and
            load its ``"model"`` entry into the network.
        **kwargs: Extra keyword arguments forwarded to ``VisionMamba``.

    Returns:
        The constructed model, with ``default_cfg`` set from ``_cfg()``.
    """
    attention_settings = {
        "causal": False,
        "d_conv": 4,
        "head_dim": 128,
        "num_heads": 16,
        "out_proj_bias": False,
        "qkv_proj_bias": False,
        "rotary_emb_dim": 0,
    }
    arch = dict(
        patch_size=16,
        stride=16,
        embed_dim=512,
        depth=12,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=False,
        if_rope_residual=False,
        bimamba_type="v2",
        if_cls_token=False,
        if_devide_out=True,
        use_middle_cls_token=False,
        if_3d=True,
        mixer_type='mamba2',
        attn_cfg=attention_settings,
        attn_layer_idx=[3, 6, 9],
    )
    # Unpacking both mappings keeps the TypeError on duplicated keywords.
    net = VisionMamba(**arch, **kwargs)
    net.default_cfg = _cfg()
    if not pretrained:
        return net
    # NOTE(review): "to.do" is a placeholder, not a real checkpoint URL.
    state = torch.hub.load_state_dict_from_url(
        url="to.do", map_location="cpu", check_hash=True
    )
    net.load_state_dict(state["model"])
    return net