-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
dsvt.py
616 lines (532 loc) · 31.3 KB
/
dsvt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from math import ceil
from pcdet.models.model_utils.dsvt_utils import get_window_coors, get_inner_win_inds_cuda, get_pooling_index, get_continous_inds
from pcdet.models.model_utils.dsvt_utils import PositionEmbeddingLearned
class DSVT(nn.Module):
    '''Dynamic Sparse Voxel Transformer Backbone.

    Args:
        INPUT_LAYER: Config of input layer, which converts the output of vfe to dsvt input.
        block_name (list[string]): Name of blocks for each stage. Length: stage_num.
        set_info (list[list[int, int]]): A list of set config for each stage. Element i contains
            [set_size, block_num], where set_size is the number of voxels in a set and block_num is the
            number of blocks for stage i. Length: stage_num.
        d_model (list[int]): Number of input channels for each stage. Length: stage_num.
        nhead (list[int]): Number of attention heads for each stage. Length: stage_num.
        dim_feedforward (list[int]): Dimensions of the feedforward network in set attention for each stage.
            Length: stage_num.
        dropout (float): Drop rate of set attention.
        activation (string): Name of activation layer in set attention.
        reduction_type (string, optional): Pooling method between stages. One of: "attention" (default),
            "maxpool", "linear".
        output_shape (tuple[int, int]): Shape of output bev feature.
        conv_out_channel (int): Number of output channels.
        USE_CHECKPOINT (bool, optional): If true, every block is run through
            torch.utils.checkpoint to trade compute for GPU memory. Default: False.
    '''

    def __init__(self, model_cfg, **kwargs):
        super().__init__()
        self.model_cfg = model_cfg
        self.input_layer = DSVTInputLayer(self.model_cfg.INPUT_LAYER)

        block_name = self.model_cfg.block_name
        set_info = self.model_cfg.set_info
        d_model = self.model_cfg.d_model
        nhead = self.model_cfg.nhead
        dim_feedforward = self.model_cfg.dim_feedforward
        dropout = self.model_cfg.dropout
        activation = self.model_cfg.activation
        self.reduction_type = self.model_cfg.get('reduction_type', 'attention')
        # save GPU memory
        self.use_torch_ckpt = self.model_cfg.get('USE_CHECKPOINT', False)

        # Sparse Regional Attention Blocks.
        # The attribute names registered below (stage_{i}, residual_norm_stage_{i},
        # stage_{i}_reduction) are state_dict keys, so they must stay stable.
        stage_num = len(block_name)
        for stage_id in range(stage_num):
            num_blocks_this_stage = set_info[stage_id][-1]
            dmodel_this_stage = d_model[stage_id]
            dfeed_this_stage = dim_feedforward[stage_id]
            num_head_this_stage = nhead[stage_id]
            block_module = _get_block_module(block_name[stage_id])

            block_list = []
            norm_list = []
            for _ in range(num_blocks_this_stage):
                block_list.append(
                    block_module(dmodel_this_stage, num_head_this_stage, dfeed_this_stage,
                                 dropout, activation, batch_first=True)
                )
                norm_list.append(nn.LayerNorm(dmodel_this_stage))
            setattr(self, f'stage_{stage_id}', nn.ModuleList(block_list))
            setattr(self, f'residual_norm_stage_{stage_id}', nn.ModuleList(norm_list))

            # apply pooling except the last stage
            if stage_id < stage_num - 1:
                downsample_window = self.model_cfg.INPUT_LAYER.downsample_stride[stage_id]
                dmodel_next_stage = d_model[stage_id + 1]
                # number of voxels merged into one pooled voxel
                pool_volume = torch.IntTensor(downsample_window).prod().item()
                if self.reduction_type == 'linear':
                    cat_feat_dim = dmodel_this_stage * pool_volume
                    reduction = Stage_Reduction_Block(cat_feat_dim, dmodel_next_stage)
                elif self.reduction_type == 'maxpool':
                    reduction = torch.nn.MaxPool1d(pool_volume)
                elif self.reduction_type == 'attention':
                    reduction = Stage_ReductionAtt_Block(dmodel_this_stage, pool_volume)
                else:
                    raise NotImplementedError
                setattr(self, f'stage_{stage_id}_reduction', reduction)

        self.num_shifts = [2] * stage_num
        self.output_shape = self.model_cfg.output_shape
        self.stage_num = stage_num
        self.set_info = set_info
        self.num_point_features = self.model_cfg.conv_out_channel

        self._reset_parameters()

    def forward(self, batch_dict):
        '''
        Args:
            batch_dict (dict):
                The dict contains the following keys
                - voxel_features (Tensor[float]): Voxel features after VFE. Shape of (N, d_model[0]),
                    where N is the number of input voxels.
                - voxel_coords (Tensor[int]): Shape of (N, 4), corresponding voxel coordinates of each voxels.
                    Each row is (batch_id, z, y, x).
                - ...
        Returns:
            batch_dict (dict):
                The input dict, updated in place with
                - pillar_features / voxel_features (Tensor[float]): Features produced by the last
                    stage (both keys refer to the same tensor).
                - voxel_coords (Tensor[int]): Voxel coordinates of the last stage.
                - ...
        '''
        voxel_info = self.input_layer(batch_dict)

        voxel_feat = voxel_info['voxel_feats_stage0']
        set_voxel_inds_list = [[voxel_info[f'set_voxel_inds_stage{s}_shift{i}'] for i in range(self.num_shifts[s])]
                               for s in range(self.stage_num)]
        set_voxel_masks_list = [[voxel_info[f'set_voxel_mask_stage{s}_shift{i}'] for i in range(self.num_shifts[s])]
                                for s in range(self.stage_num)]
        pos_embed_list = [[[voxel_info[f'pos_embed_stage{s}_block{b}_shift{i}'] for i in range(self.num_shifts[s])]
                           for b in range(self.set_info[s][1])]
                          for s in range(self.stage_num)]
        pooling_mapping_index = [voxel_info[f'pooling_mapping_index_stage{s+1}'] for s in range(self.stage_num - 1)]
        pooling_index_in_pool = [voxel_info[f'pooling_index_in_pool_stage{s+1}'] for s in range(self.stage_num - 1)]
        pooling_preholder_feats = [voxel_info[f'pooling_preholder_feats_stage{s+1}'] for s in range(self.stage_num - 1)]

        output = voxel_feat
        # Global block counter across stages; DSVTBlock uses it to alternate window shifts.
        block_id = 0
        for stage_id in range(self.stage_num):
            block_layers = getattr(self, f'stage_{stage_id}')
            residual_norm_layers = getattr(self, f'residual_norm_stage_{stage_id}')
            for i, (block, residual_norm) in enumerate(zip(block_layers, residual_norm_layers)):
                # DSVTBlock never modifies its input in place, so the residual needs no copy.
                residual = output
                block_args = (output, set_voxel_inds_list[stage_id], set_voxel_masks_list[stage_id],
                              pos_embed_list[stage_id][i], block_id)
                if self.use_torch_ckpt:
                    output = checkpoint(block, *block_args)
                else:
                    output = block(*block_args)
                output = residual_norm(output + residual)
                block_id += 1

            if stage_id < self.stage_num - 1:
                # pooling: scatter every voxel into its slot of the pooled voxel it falls into,
                # then reduce each group of `pool_volume` slots to a single feature vector.
                prepool_features = pooling_preholder_feats[stage_id].type_as(output)
                pooled_voxel_num = prepool_features.shape[0]
                pool_volume = prepool_features.shape[1]
                prepool_features[pooling_mapping_index[stage_id], pooling_index_in_pool[stage_id]] = output
                prepool_features = prepool_features.view(pooled_voxel_num, -1)

                reduction = getattr(self, f'stage_{stage_id}_reduction')
                if self.reduction_type == 'linear':
                    output = reduction(prepool_features)
                elif self.reduction_type == 'maxpool':
                    prepool_features = prepool_features.view(pooled_voxel_num, pool_volume, -1).permute(0, 2, 1)
                    output = reduction(prepool_features).squeeze(-1)
                elif self.reduction_type == 'attention':
                    prepool_features = prepool_features.view(pooled_voxel_num, pool_volume, -1).permute(0, 2, 1)
                    # Nothing is masked: every slot (including zero-filled empty ones) is a key.
                    # nn.MultiheadAttention only accepts bool or float masks; an integer mask is
                    # rejected by recent PyTorch releases, so build a bool mask directly.
                    key_padding_mask = torch.zeros((pooled_voxel_num, pool_volume), dtype=torch.bool,
                                                   device=prepool_features.device)
                    output = reduction(prepool_features, key_padding_mask)
                else:
                    raise NotImplementedError

        batch_dict['pillar_features'] = batch_dict['voxel_features'] = output
        batch_dict['voxel_coords'] = voxel_info[f'voxel_coors_stage{self.stage_num - 1}']
        return batch_dict

    def _reset_parameters(self):
        '''Xavier-uniform init for every parameter with more than one dimension.

        Parameters whose name contains "scaler" are skipped. Note this also overwrites
        multi-dimensional parameters that submodules initialised themselves (e.g. the
        position embedding of Stage_ReductionAtt_Block).
        '''
        for name, p in self.named_parameters():
            if p.dim() > 1 and 'scaler' not in name:
                nn.init.xavier_uniform_(p)
class DSVTBlock(nn.Module):
    '''Two DSVT encoder layers applied back to back on the same window shift.

    The first encoder layer attends within the sets stored at index 0 of the
    per-shift set tensors, the second within the sets stored at index 1. Which
    window shift is used alternates with the global ``block_id`` (``block_id % 2``).
    '''

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", batch_first=True):
        super().__init__()
        self.encoder_list = nn.ModuleList([
            DSVT_EncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, batch_first)
            for _ in range(2)
        ])

    def forward(
        self,
        src,
        set_voxel_inds_list,
        set_voxel_masks_list,
        pos_embed_list,
        block_id,
    ):
        # Even blocks use window shift 0, odd blocks use shift 1.
        shift_id = block_id % 2
        inds_for_shift = set_voxel_inds_list[shift_id]
        masks_for_shift = set_voxel_masks_list[shift_id]

        # TODO: bug to be fixed, mismatch of pos_embed.
        # The position embedding is picked by encoder-layer index rather than by
        # `shift_id`, so it may come from the other shift's window coordinates.
        # Left unchanged here because fixing it changes the model's outputs.
        features = src
        for layer_idx, encoder in enumerate(self.encoder_list):
            features = encoder(
                features,
                inds_for_shift[layer_idx],
                masks_for_shift[layer_idx],
                pos_embed_list[layer_idx],
            )
        return features
class DSVT_EncoderLayer(nn.Module):
    '''Set attention wrapped in a post-norm residual connection (add, then LayerNorm).'''

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", batch_first=True, mlp_dropout=0):
        super().__init__()
        self.win_attn = SetAttention(d_model, nhead, dropout, dim_feedforward, activation, batch_first, mlp_dropout)
        self.norm = nn.LayerNorm(d_model)
        self.d_model = d_model

    def forward(self, src, set_voxel_inds, set_voxel_masks, pos=None):
        # SetAttention takes (features, pos, key_padding_mask, voxel_inds).
        attended = self.win_attn(src, pos, set_voxel_masks, set_voxel_inds)
        return self.norm(attended + src)
class SetAttention(nn.Module):
    '''Multi-head self-attention within voxel sets, followed by a feed-forward network.

    Both sub-layers use a post-norm residual connection (add, then LayerNorm).

    Args:
        d_model (int): Feature dimension C.
        nhead (int): Number of attention heads.
        dropout (float): Dropout inside nn.MultiheadAttention.
        dim_feedforward (int): Hidden size of the feed-forward network.
        activation (str): Activation name, resolved by _get_activation_fn.
        batch_first (bool): Passed to nn.MultiheadAttention when true.
        mlp_dropout (float): Dropout applied after the feed-forward activation.
    '''

    def __init__(self, d_model, nhead, dropout, dim_feedforward=2048, activation="relu", batch_first=True, mlp_dropout=0):
        super().__init__()
        self.nhead = nhead
        if batch_first:
            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
        else:
            # NOTE(review): forward() feeds (set_num, set_size, C) tensors, i.e. a batch-first
            # layout; with batch_first=False attention would run across sets instead of within
            # them. All callers in this file pass batch_first=True.
            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(mlp_dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.d_model = d_model
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Identity()
        self.dropout2 = nn.Identity()

        self.activation = _get_activation_fn(activation)

    def forward(self, src, pos=None, key_padding_mask=None, voxel_inds=None):
        '''
        Args:
            src (Tensor[float]): Voxel features with shape (N, C), where N is the number of voxels.
            pos (Tensor[float], optional): Position embedding vectors with shape (N, C). When None,
                attention is computed on the raw features.
            key_padding_mask (Tensor[bool], optional): Mask for redundant voxels within set. Shape of
                (set_num, set_size); True marks a key position to ignore.
            voxel_inds (Tensor[int]): Voxel indices for each set. Shape of (set_num, set_size).
                Every voxel index in [0, N) must appear at least once: the attention output is
                mapped back to voxel space in sorted-unique-index order and added to src row by row.
        Returns:
            src (Tensor[float]): Voxel features with shape (N, C).
        '''
        set_features = src[voxel_inds]  # (set_num, set_size, C)
        if pos is not None:
            set_pos = pos[voxel_inds]
            query = key = set_features + set_pos
        else:
            # Previously query/key/value were never assigned on this path (UnboundLocalError).
            query = key = set_features
        value = set_features

        src2 = self.self_attn(query, key, value, key_padding_mask=key_padding_mask)[0]

        # map voxel features from set space to voxel space: (set_num, set_size, C) --> (N, C)
        flatten_inds = voxel_inds.reshape(-1)
        unique_flatten_inds, inverse = torch.unique(flatten_inds, return_inverse=True)
        perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
        # Flip before scattering so that, for a voxel occurring several times, the write from its
        # first occurrence comes last and wins. PyTorch does not guarantee which value wins when
        # scatter_ indices repeat (notably on CUDA), so this tie-break is best-effort.
        inverse, perm = inverse.flip([0]), perm.flip([0])
        perm = inverse.new_empty(unique_flatten_inds.size(0)).scatter_(0, inverse, perm)
        src2 = src2.reshape(-1, self.d_model)[perm]

        # FFN layer
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src
class Stage_Reduction_Block(nn.Module):
    '''Linear down-sampling between stages: bias-free projection followed by LayerNorm.

    Args:
        input_channel (int): Size of the concatenated per-pool feature vector.
        output_channel (int): Number of channels of the next stage.
    '''

    def __init__(self, input_channel, output_channel):
        super().__init__()
        self.linear1 = nn.Linear(input_channel, output_channel, bias=False)
        self.norm = nn.LayerNorm(output_channel)

    def forward(self, x):
        '''Project ``x`` (..., input_channel) to (..., output_channel) and normalise it.'''
        return self.norm(self.linear1(x))
class Stage_ReductionAtt_Block(nn.Module):
    '''Attention-based pooling between stages.

    The max over each pool is the single query; every slot of the pool is a key/value,
    and the keys additionally get a learned per-slot position embedding.

    Args:
        input_channel (int): Feature dimension C (unchanged by this block).
        pool_volume (int): Number of slots per pool.
        num_heads (int, optional): Number of attention heads. Default: 8. input_channel
            must be divisible by it.
    '''

    def __init__(self, input_channel, pool_volume, num_heads=8):
        super().__init__()
        self.pool_volume = pool_volume
        self.query_func = torch.nn.MaxPool1d(pool_volume)
        self.norm = nn.LayerNorm(input_channel)
        self.self_attn = nn.MultiheadAttention(input_channel, num_heads, batch_first=True)
        self.pos_embedding = nn.Parameter(torch.randn(pool_volume, input_channel))
        nn.init.normal_(self.pos_embedding, std=.01)

    def forward(self, x, key_padding_mask=None):
        '''
        Args:
            x (Tensor[float]): Features of shape (voxel_num, C, pool_volume).
            key_padding_mask (Tensor, optional): Shape (voxel_num, pool_volume); True / nonzero
                entries are slots to ignore. Integer masks are converted to bool because
                nn.MultiheadAttention only accepts bool or float masks. Float masks are
                passed through unchanged (additive semantics).
        Returns:
            Tensor[float]: Reduced features of shape (voxel_num, C).
        '''
        src = self.query_func(x).permute(0, 2, 1)  # (voxel_num, 1, C)
        value = x.permute(0, 2, 1)  # (voxel_num, pool_volume, C)
        key = value + self.pos_embedding.unsqueeze(0)  # broadcast over voxels
        if (key_padding_mask is not None and key_padding_mask.dtype != torch.bool
                and not torch.is_floating_point(key_padding_mask)):
            key_padding_mask = key_padding_mask.bool()
        output = self.self_attn(src, key, value, key_padding_mask=key_padding_mask)[0]
        return self.norm(output + src).squeeze(1)
def _get_activation_fn(activation):
    """Return the functional activation named by ``activation``.

    Args:
        activation (str): One of "relu", "gelu" or "glu".

    Returns:
        Callable: ``torch.nn.functional.relu``, ``gelu`` or ``glu``.

    Raises:
        RuntimeError: If ``activation`` is not a supported name.

    Note:
        ``torch.nn.functional.glu`` halves the size of the last dimension, so a
        feed-forward layer using it must be sized accordingly.
    """
    functional = torch.nn.functional
    if activation == "relu":
        return functional.relu
    if activation == "gelu":
        return functional.gelu
    if activation == "glu":
        return functional.glu
    # The message previously listed only relu/gelu although glu is accepted too.
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
def _get_block_module(name):
    """Return the block class registered under ``name``.

    Args:
        name (str): Block name from the config; only "DSVTBlock" is supported.

    Returns:
        type: The block class (``DSVTBlock``).

    Raises:
        RuntimeError: If ``name`` is not a known block (the message names it).
    """
    if name == "DSVTBlock":
        return DSVTBlock
    raise RuntimeError(f"Block {name!r} does not exist.")
class DSVTInputLayer(nn.Module):
    '''
    This class converts the output of vfe to dsvt input.
    We do in this class:
    1. Window partition: partition voxels to non-overlapping windows.
    2. Set partition: generate non-overlapped and size-equivalent local sets within each window.
    3. Pre-compute the downsample information between two consecutive stages.
    4. Pre-compute the position embedding vectors.

    Args:
        sparse_shape (tuple[int, int, int]): Shape of input space (xdim, ydim, zdim).
        window_shape (list[list[int, int, int]]): Window shapes (winx, winy, winz) in different stages. Length: stage_num.
        downsample_stride (list[list[int, int, int]]): Downsample strides between two consecutive stages.
            Element i is [ds_x, ds_y, ds_z], which is used between stage_i and stage_{i+1}. Length: stage_num - 1.
        d_model (list[int]): Number of input channels for each stage. Length: stage_num.
        set_info (list[list[int, int]]): A list of set config for each stage. Element i contains
            [set_size, block_num], where set_size is the number of voxels in a set and block_num is the
            number of blocks for stage i. Length: stage_num.
        hybrid_factor (list[int, int, int]): Scales the window shape of window shift 1.
            For stage s, shift 0 uses [win_x, win_y, win_z] and shift 1 uses
            [win_x * h[0], win_y * h[1], win_z * h[2]]; DSVTBlock alternates the shift per block
            (block_id % 2), so e.g. block_{0} and block_{1} of stage_0 use these two shapes respectively.
        shifts_list (list): Shift window config, indexed as [stage_id][shift_id]. Length: stage_num.
            Stored as self.shift_list.
        normalize_pos (bool): Whether to normalize coordinates in position embedding.
    '''
    def __init__(self, model_cfg):
        super().__init__()
        self.model_cfg = model_cfg
        self.sparse_shape = self.model_cfg.sparse_shape
        self.window_shape = self.model_cfg.window_shape
        self.downsample_stride = self.model_cfg.downsample_stride
        self.d_model = self.model_cfg.d_model
        self.set_info = self.model_cfg.set_info
        self.stage_num = len(self.d_model)

        self.hybrid_factor = self.model_cfg.hybrid_factor
        # After this, self.window_shape[stage_id][shift_id] is the window of that shift:
        # shift 0 keeps the configured window, shift 1 scales it element-wise by hybrid_factor.
        self.window_shape = [[self.window_shape[s_id], [self.window_shape[s_id][coord_id] * self.hybrid_factor[coord_id] \
            for coord_id in range(3)]] for s_id in range(self.stage_num)]
        self.shift_list = self.model_cfg.shifts_list
        self.normalize_pos = self.model_cfg.normalize_pos

        self.num_shifts = [2,] * len(self.window_shape)

        self.sparse_shape_list = [self.sparse_shape]
        # compute sparse shapes for each stage (ceil-divide by the downsample stride)
        for ds_stride in self.downsample_stride:
            last_sparse_shape = self.sparse_shape_list[-1]
            self.sparse_shape_list.append((ceil(last_sparse_shape[0]/ds_stride[0]), ceil(last_sparse_shape[1]/ds_stride[1]), ceil(last_sparse_shape[2]/ds_stride[2])))

        # position embedding layers, indexed as [stage][block][shift]
        self.posembed_layers = nn.ModuleList()
        for i in range(len(self.set_info)):
            # NOTE(review): input_dim is derived from the stage's sparse z extent, while
            # get_pos_embed picks 2D vs 3D from the window's z size. These must agree
            # (window z == 1 exactly when sparse z == 1), otherwise the embedding layer
            # presumably receives the wrong input width.
            input_dim = 3 if self.sparse_shape_list[i][-1] > 1 else 2
            stage_posembed_layers = nn.ModuleList()
            for j in range(self.set_info[i][1]):
                block_posembed_layers = nn.ModuleList()
                for s in range(self.num_shifts[i]):
                    block_posembed_layers.append(PositionEmbeddingLearned(input_dim, self.d_model[i]))
                stage_posembed_layers.append(block_posembed_layers)
            self.posembed_layers.append(stage_posembed_layers)

    def forward(self, batch_dict):
        '''
        Args:
            batch_dict (dict):
                The dict contains the following keys
                - voxel_features (Tensor[float]): Voxel features after VFE with shape (N, d_model[0]),
                    where N is the number of input voxels.
                - voxel_coords (Tensor[int]): Shape of (N, 4), corresponding voxel coordinates of each voxels.
                    Each row is (batch_id, z, y, x).
                - ...
        Returns:
            voxel_info (dict):
                The dict contains the following keys
                - voxel_coors_stage{i} (Tensor[int]): Shape of (N_i, 4). N_i is the number of voxels in stage_i.
                    Each row is (batch_id, z, y, x).
                - set_voxel_inds_stage{i}_shift{j} (Tensor[int]): Set partition index with shape (2, set_num, set_info[i][0]).
                    Index 0 along the first axis is the y-axis partition and index 1 the x-axis partition.
                - set_voxel_mask_stage{i}_shift{j} (Tensor[bool]): Key mask used in set attention with shape (2, set_num, set_info[i][0]).
                    True marks an entry that repeats the previous voxel in its set.
                - pos_embed_stage{i}_block{j}_shift{k} (Tensor[float]): Position embedding vectors with shape (N_i, d_model[i]). N_i is the
                    number of remaining voxels in stage_i;
                - pooling_mapping_index_stage{i} (Tensor[int]): Pooling region index used in pooling operation between stage_{i-1} and stage_{i}
                    with shape (N_{i-1}).
                - pooling_index_in_pool_stage{i} (Tensor[int]): Index inside the pooling region with shape (N_{i-1}). Combined with pooling_mapping_index_stage{i},
                    we can map each voxel in stage_{i-1} to pooling_preholder_feats_stage{i}, which are input of downsample operation.
                - pooling_preholder_feats_stage{i} (Tensor[float]): Placeholder features initialized with value 0.
                    Shape of (N_{i}, downsample_stride[i-1].prod(), d_model[i-1]), where prod() returns the product of all elements.
                - ...
        '''
        voxel_feats = batch_dict['voxel_features']
        voxel_coors = batch_dict['voxel_coords'].long()

        voxel_info = {}
        voxel_info['voxel_feats_stage0'] = voxel_feats.clone()
        voxel_info['voxel_coors_stage0'] = voxel_coors.clone()

        for stage_id in range(self.stage_num):
            # window partition of corresponding stage-map
            voxel_info = self.window_partition(voxel_info, stage_id)
            # generate set id of corresponding stage-map
            voxel_info = self.get_set(voxel_info, stage_id)
            for block_id in range(self.set_info[stage_id][1]):
                for shift_id in range(self.num_shifts[stage_id]):
                    voxel_info[f'pos_embed_stage{stage_id}_block{block_id}_shift{shift_id}'] = \
                        self.get_pos_embed(voxel_info[f'coors_in_win_stage{stage_id}_shift{shift_id}'], stage_id, block_id, shift_id)

            # compute pooling information (produces voxel_coors_stage{stage_id+1} for the next stage)
            if stage_id < self.stage_num - 1:
                voxel_info = self.subm_pooling(voxel_info, stage_id)

        return voxel_info

    @torch.no_grad()
    def subm_pooling(self, voxel_info, stage_id):
        '''
        Pre-compute how voxels of stage_{stage_id} are grouped into voxels of stage_{stage_id+1}.

        Writes pooling_mapping_index_stage{stage_id+1}, pooling_index_in_pool_stage{stage_id+1},
        pooling_preholder_feats_stage{stage_id+1} and voxel_coors_stage{stage_id+1} into voxel_info
        (see 'forward') and returns the same dict.
        '''
        # x,y,z stride
        cur_stage_downsample = self.downsample_stride[stage_id]
        # NOTE(review): original comment read "batch_win_coords is from 1 of x, y"; meaning
        # unclear — presumably about the coordinate convention returned by get_pooling_index.
        batch_win_inds, _, index_in_win, batch_win_coors = get_pooling_index(voxel_info[f'voxel_coors_stage{stage_id}'], self.sparse_shape_list[stage_id], cur_stage_downsample)
        # compute pooling mapping index: contiguous id of the pooling region each voxel falls into
        unique_batch_win_inds, contiguous_batch_win_inds = torch.unique(batch_win_inds, return_inverse=True)
        voxel_info[f'pooling_mapping_index_stage{stage_id+1}'] = contiguous_batch_win_inds

        # generate empty placeholder features: (num pooling regions, pool volume, d_model[stage_id])
        placeholder_prepool_feats = voxel_info[f'voxel_feats_stage0'].new_zeros((len(unique_batch_win_inds),
            torch.prod(torch.IntTensor(cur_stage_downsample)).item(), self.d_model[stage_id]))
        voxel_info[f'pooling_index_in_pool_stage{stage_id+1}'] = index_in_win
        voxel_info[f'pooling_preholder_feats_stage{stage_id+1}'] = placeholder_prepool_feats

        # compute pooling coordinates: for each pooling region, take batch_win_coors of the
        # first voxel mapped to it (flip + scatter lets the first occurrence be written last).
        unique, inverse = unique_batch_win_inds.clone(), contiguous_batch_win_inds.clone()
        perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
        inverse, perm = inverse.flip([0]), perm.flip([0])
        perm = inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm)
        pool_coors = batch_win_coors[perm]

        voxel_info[f'voxel_coors_stage{stage_id+1}'] = pool_coors

        return voxel_info

    def get_set(self, voxel_info, stage_id):
        '''
        This is one of the core operations of DSVT.
        Given voxels' window ids and relative coords inside each window, we partition them into window-bounded and size-equivalent local sets.
        To make it clear and easy to follow, we do not use a loop to process the two shifts.

        Args:
            voxel_info (dict):
                The dict contains the following keys
                - batch_win_inds_stage{stage_id}_shift{i} (Tensor): Window index of each voxel with shape (N), computed by 'window_partition'.
                - coors_in_win_stage{stage_id}_shift{i} (Tensor[int]): Relative coords inside the window of each voxel with shape (N, 3), computed by 'window_partition'.
                    Each row is (z, y, x).
                - ...
            stage_id (int): Index of the current stage.

        Returns:
            voxel_info (dict): The same dict, with set_voxel_inds_stage{stage_id}_shift{i} and
                set_voxel_mask_stage{stage_id}_shift{i} added for i in (0, 1). See 'forward'.
        '''
        batch_win_inds_shift0 = voxel_info[f'batch_win_inds_stage{stage_id}_shift0']
        coors_in_win_shift0 = voxel_info[f'coors_in_win_stage{stage_id}_shift0']
        set_voxel_inds_shift0 = self.get_set_single_shift(batch_win_inds_shift0, stage_id, shift_id=0, coors_in_win=coors_in_win_shift0)
        voxel_info[f'set_voxel_inds_stage{stage_id}_shift0'] = set_voxel_inds_shift0
        # compute key masks, voxel duplication must happen continuously:
        # an entry is masked (True) when it equals the voxel index right before it in the same set;
        # the first slot of every set is never masked (its "previous" value is forced to -1).
        prefix_set_voxel_inds_s0 = torch.roll(set_voxel_inds_shift0.clone(), shifts=1, dims=-1)
        prefix_set_voxel_inds_s0[ :, :, 0] = -1
        set_voxel_mask_s0 = (set_voxel_inds_shift0 == prefix_set_voxel_inds_s0)
        voxel_info[f'set_voxel_mask_stage{stage_id}_shift0'] = set_voxel_mask_s0

        batch_win_inds_shift1 = voxel_info[f'batch_win_inds_stage{stage_id}_shift1']
        coors_in_win_shift1 = voxel_info[f'coors_in_win_stage{stage_id}_shift1']
        set_voxel_inds_shift1 = self.get_set_single_shift(batch_win_inds_shift1, stage_id, shift_id=1, coors_in_win=coors_in_win_shift1)
        voxel_info[f'set_voxel_inds_stage{stage_id}_shift1'] = set_voxel_inds_shift1
        # compute key masks, voxel duplication must happen continuously (same rule as shift 0)
        prefix_set_voxel_inds_s1 = torch.roll(set_voxel_inds_shift1.clone(), shifts=1, dims=-1)
        prefix_set_voxel_inds_s1[ :, :, 0] = -1
        set_voxel_mask_s1 = (set_voxel_inds_shift1 == prefix_set_voxel_inds_s1)
        voxel_info[f'set_voxel_mask_stage{stage_id}_shift1'] = set_voxel_mask_s1

        return voxel_info

    def get_set_single_shift(self, batch_win_inds, stage_id, shift_id=None, coors_in_win=None):
        '''
        Partition the voxels of one window shift into sets of set_info[stage_id][0] voxels.

        A window holding n voxels gets ceil(n / set_size) sets. The set slots are filled by
        evenly sampling the window's voxels (Eq. 3 of the DSVT paper), so some voxels are
        repeated when n is not a multiple of set_size. This is done twice: once with the
        window's voxels ordered y-major (then x, then z) and once ordered x-major (then y, then z).

        Args:
            batch_win_inds (Tensor): Window index of each voxel, shape (N).
            stage_id (int): Index of the current stage.
            shift_id (int): Window shift (0 or 1); selects self.window_shape[stage_id][shift_id].
            coors_in_win (Tensor[int]): Coords inside the window, shape (N, 3), rows (z, y, x).

        Returns:
            Tensor[long]: Voxel indices with shape (2, set_num, set_size); index 0 holds the
            y-axis partition and index 1 the x-axis partition.
        '''
        device = batch_win_inds.device
        # the number of voxels assigned to a set
        voxel_num_set = self.set_info[stage_id][0]
        # max number of voxels in a window
        max_voxel = self.window_shape[stage_id][shift_id][0] * self.window_shape[stage_id][shift_id][1] * self.window_shape[stage_id][shift_id][2]
        # get unique set indices
        contiguous_win_inds = torch.unique(batch_win_inds, return_inverse=True)[1]
        voxelnum_per_win = torch.bincount(contiguous_win_inds)
        win_num = voxelnum_per_win.shape[0]

        setnum_per_win_float = voxelnum_per_win / voxel_num_set
        setnum_per_win = torch.ceil(setnum_per_win_float).long()

        set_win_inds, set_inds_in_win = get_continous_inds(setnum_per_win)

        # computation of Eq.3 in 'DSVT: Dynamic Sparse Voxel Transformer with Rotated Sets' - https://arxiv.org/abs/2301.06051,
        # for each window, we can get voxel indices belonging to different sets:
        # slot p of the window maps to inner index floor(p * n / (setnum * set_size)).
        offset_idx = set_inds_in_win[:,None].repeat(1, voxel_num_set) * voxel_num_set
        base_idx = torch.arange(0, voxel_num_set, 1, device=device)
        base_select_idx = offset_idx + base_idx
        base_select_idx = base_select_idx * voxelnum_per_win[set_win_inds][:,None]
        base_select_idx = base_select_idx.double() / (setnum_per_win[set_win_inds] * voxel_num_set)[:,None].double()
        base_select_idx = torch.floor(base_select_idx)
        # obtain unique indices in whole space (offset by window id * max_voxel)
        select_idx = base_select_idx
        select_idx = select_idx + set_win_inds.view(-1, 1) * max_voxel

        # this function will return unordered inner window indices of each voxel
        inner_voxel_inds = get_inner_win_inds_cuda(contiguous_win_inds)
        global_voxel_inds = contiguous_win_inds * max_voxel + inner_voxel_inds
        _, order1 = torch.sort(global_voxel_inds)

        # get y-axis partition results (sort key: window, then y, then x, then z)
        global_voxel_inds_sorty = contiguous_win_inds * max_voxel + \
            coors_in_win[:,1] * self.window_shape[stage_id][shift_id][0] * self.window_shape[stage_id][shift_id][2] + \
            coors_in_win[:,2] * self.window_shape[stage_id][shift_id][2] + \
            coors_in_win[:,0]
        _, order2 = torch.sort(global_voxel_inds_sorty)

        inner_voxel_inds_sorty = -torch.ones_like(inner_voxel_inds)
        inner_voxel_inds_sorty.scatter_(dim=0, index=order2, src=inner_voxel_inds[order1]) # get y-axis ordered inner window indices of each voxel
        voxel_inds_in_batch_sorty = inner_voxel_inds_sorty + max_voxel * contiguous_win_inds
        # lookup table from (window, ordered inner index) to the voxel's row in the input; unused slots stay -1
        voxel_inds_padding_sorty = -1 * torch.ones((win_num * max_voxel), dtype=torch.long, device=device)
        voxel_inds_padding_sorty[voxel_inds_in_batch_sorty] = torch.arange(0, voxel_inds_in_batch_sorty.shape[0], dtype=torch.long, device=device)
        set_voxel_inds_sorty = voxel_inds_padding_sorty[select_idx.long()]

        # get x-axis partition results (sort key: window, then x, then y, then z)
        global_voxel_inds_sortx = contiguous_win_inds * max_voxel + \
            coors_in_win[:,2] * self.window_shape[stage_id][shift_id][1] * self.window_shape[stage_id][shift_id][2] + \
            coors_in_win[:,1] * self.window_shape[stage_id][shift_id][2] + \
            coors_in_win[:,0]
        _, order2 = torch.sort(global_voxel_inds_sortx)

        inner_voxel_inds_sortx = -torch.ones_like(inner_voxel_inds)
        inner_voxel_inds_sortx.scatter_(dim=0,index=order2, src=inner_voxel_inds[order1]) # get x-axis ordered inner window indices of each voxel
        voxel_inds_in_batch_sortx = inner_voxel_inds_sortx + max_voxel * contiguous_win_inds
        voxel_inds_padding_sortx = -1 * torch.ones((win_num * max_voxel), dtype=torch.long, device=device)
        voxel_inds_padding_sortx[voxel_inds_in_batch_sortx] = torch.arange(0, voxel_inds_in_batch_sortx.shape[0], dtype=torch.long, device=device)
        set_voxel_inds_sortx = voxel_inds_padding_sortx[select_idx.long()]

        all_set_voxel_inds = torch.stack((set_voxel_inds_sorty, set_voxel_inds_sortx), dim=0)

        return all_set_voxel_inds

    @torch.no_grad()
    def window_partition(self, voxel_info, stage_id):
        '''
        Compute, for both window shifts of a stage, each voxel's window index and its
        coordinates inside the window.

        Writes batch_win_inds_stage{stage_id}_shift{i} and coors_in_win_stage{stage_id}_shift{i}
        for i in (0, 1) and returns the same dict. The boolean `i == 1` passed to
        get_window_coors presumably selects the shifted partition — confirm against dsvt_utils.
        '''
        for i in range(2):
            batch_win_inds, coors_in_win = get_window_coors(voxel_info[f'voxel_coors_stage{stage_id}'],
                self.sparse_shape_list[stage_id], self.window_shape[stage_id][i], i == 1, self.shift_list[stage_id][i])

            voxel_info[f'batch_win_inds_stage{stage_id}_shift{i}'] = batch_win_inds
            voxel_info[f'coors_in_win_stage{stage_id}_shift{i}'] = coors_in_win

        return voxel_info

    def get_pos_embed(self, coors_in_win, stage_id, block_id, shift_id):
        '''
        Compute the learned position embedding of each voxel from its coords inside the window.

        Args:
            coors_in_win: shape=[N, 3], order: z, y, x
            stage_id, block_id, shift_id (int): Select the embedding layer
                self.posembed_layers[stage_id][block_id][shift_id] and the window shape.
        Returns:
            The output of that embedding layer for the centred (x, y) or (x, y, z) locations.
        '''
        window_shape = self.window_shape[stage_id][shift_id]

        embed_layer = self.posembed_layers[stage_id][block_id][shift_id]
        # 2D embedding when the window has no z extent (win_z == 1), otherwise 3D
        if len(window_shape) == 2:
            ndim = 2
            win_x, win_y = window_shape
            win_z = 0
        elif window_shape[-1] == 1:
            ndim = 2
            win_x, win_y = window_shape[:2]
            win_z = 0
        else:
            win_x, win_y, win_z = window_shape
            ndim = 3

        assert coors_in_win.size(1) == 3
        # centre the coordinates on the window
        z, y, x = coors_in_win[:, 0] - win_z/2, coors_in_win[:, 1] - win_y/2, coors_in_win[:, 2] - win_x/2

        if self.normalize_pos:
            # scale roughly to [-pi, pi]; in the 2D case win_z is 0, so z becomes inf/nan
            # here, but z is not used below when ndim == 2
            x = x / win_x * 2 * 3.1415 #[-pi, pi]
            y = y / win_y * 2 * 3.1415 #[-pi, pi]
            z = z / win_z * 2 * 3.1415 #[-pi, pi]

        if ndim==2:
            location = torch.stack((x, y), dim=-1)
        else:
            location = torch.stack((x, y, z), dim=-1)
        pos_embed = embed_layer(location)

        return pos_embed