-
Notifications
You must be signed in to change notification settings - Fork 7
/
conv.py
1238 lines (1160 loc) · 55.8 KB
/
conv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2021 Yan Yan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import time
import sys
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter
from spconv import pytorch as spconv
from spconv.core import ConvAlgo
from spconv.debug_utils import spconv_save_debug_data
from spconv.pytorch import functional as Fsp
from spconv.pytorch import ops
from spconv.cppconstants import CPU_ONLY_BUILD
from spconv.pytorch.core import IndiceData, SparseConvTensor, ImplicitGemmIndiceData, expand_nd
from spconv.pytorch.modules import SparseModule
from spconv.constants import FILTER_HWIO
from spconv.utils import nullcontext
from torch.nn.init import calculate_gain
from spconv.pytorch.utils import split_voxels, check_repeat
class SparseConvolution(SparseModule):
    """N-d sparse convolution shared by all concrete conv modules in this file.

    One implementation covers regular, submanifold (``subm``), transposed,
    inverse and the SPSS / SPRS / focal variants. Index pairs are produced
    either by the native path (``ConvAlgo.Native``) or by the implicit-GEMM
    path, and can be cached on the input tensor under ``indice_key`` so that
    later layers can reuse them.
    """
    __constants__ = [
        'stride', 'padding', 'dilation', 'groups', 'bias', 'subm', 'inverse',
        'transposed', 'output_padding'
    ]

    def __init__(self,
                 ndim: int,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]] = 3,
                 stride: Union[int, List[int], Tuple[int, ...]] = 1,
                 padding: Union[int, List[int], Tuple[int, ...]] = 0,
                 dilation: Union[int, List[int], Tuple[int, ...]] = 1,
                 groups: Union[int, List[int], Tuple[int, ...]] = 1,
                 bias: bool = True,
                 subm: bool = False,
                 spss: bool = False,
                 sprs: bool = False,
                 focal: bool = False,
                 output_padding: Union[int, List[int], Tuple[int, ...]] = 0,
                 transposed: bool = False,
                 inverse: bool = False,
                 indice_key: Optional[str] = None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 position_embedding: bool = False,
                 name=None,
                 spatial_groups: Union[int, List[int], Tuple[int, ...]] = 1):
        """Validate the configuration and allocate weight / bias.

        Args:
            ndim: number of spatial dimensions.
            in_channels, out_channels: feature widths; both must be divisible
                by ``groups``.
            kernel_size, stride, padding, dilation, output_padding: scalars
                or per-dimension sequences, expanded with ``expand_nd``.
            groups: channel groups; ``groups > 1`` requires
                ``ConvAlgo.Native``.
            bias: allocate an additive bias of size ``out_channels``.
            subm: submanifold conv (keeps the input spatial shape).
            spss, sprs, focal: alternative index-generation variants. They
                keep the input spatial shape and require ``ConvAlgo.Native``;
                spss/focal also need ``in_channels == out_channels`` and are
                mutually exclusive.
            transposed: compute a transposed (deconvolution) output shape.
            inverse: reuse, in reverse, the index pairs cached under
                ``indice_key`` by an earlier non-submanifold conv / pool.
            indice_key: cache key for sharing index pairs between layers.
            algo: index/GEMM algorithm; chosen automatically when ``None``.
            fp32_accum: forwarded to ``Fsp.implicit_gemm``.
            position_embedding: allocate a learnable
                ``(kernel_size_ori ** ndim, in_channels)`` parameter that is
                passed to the submanifold conv op.
            name: module name; required when the input runs in benchmark mode.
            spatial_groups: when > 1 (submanifold only) the weight is
                allocated with ``spatial_groups`` as its kernel size, while
                the requested ``kernel_size`` is kept in ``kernel_size_ori``.

        Raises:
            AssertionError: on any unsupported combination of options.
        """
        super(SparseConvolution, self).__init__(name=name)
        assert in_channels % groups == 0 and out_channels % groups == 0, "channels should be divisible by groups"
        if spatial_groups > 1:
            assert subm, "spatial wise groups only support submanifold conv for now"
            kernel_size_ori = kernel_size
            kernel_size = spatial_groups
        else:
            # sentinel: no separate "original" kernel size
            kernel_size_ori = -1
        self.spatial_groups = spatial_groups
        self.ndim = ndim
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = expand_nd(ndim, kernel_size)
        self.kernel_size_ori = expand_nd(ndim, kernel_size_ori)
        self.stride = expand_nd(ndim, stride)
        kv = int(np.prod(self.kernel_size))
        kv_stride = int(np.prod(self.stride))
        self.dilation = expand_nd(ndim, dilation)
        self.padding = expand_nd(ndim, padding)
        self.conv1x1 = kv == 1
        # TODO we should deprecate support for ksize == 1 but stride != 1.
        if not subm:
            self.conv1x1 &= kv_stride == 1
        if spss or focal:
            assert not self.conv1x1, "SPSS not support conv1x1 now"
        if self.conv1x1:
            assert self.padding == [
                0
            ] * ndim, "padding must be zero for 1x1 conv (k=1,s=1)"
        self.transposed = transposed
        self.inverse = inverse
        self.output_padding = expand_nd(ndim, output_padding)
        self.groups = groups
        self.subm = subm
        self.spss = spss
        self.sprs = sprs
        self.focal = focal
        self.indice_key = indice_key
        # NOTE(review): torch.Tensor(...) is uninitialized memory and
        # reset_parameters() never initializes position_embedding -- confirm
        # it is initialized elsewhere. Also, kernel_size_ori is -1 unless
        # spatial_groups > 1, so kernel_size_ori ** ndim only yields a
        # meaningful size with spatial grouping and an int kernel_size.
        self.position_embedding = Parameter(
            torch.Tensor(kernel_size_ori**ndim, in_channels)) if position_embedding else None
        if algo is None:
            # TODO spss and focal only support native algorithm now
            if self.spss or self.focal:
                algo = ConvAlgo.Native
            elif kv <= 32 and not CPU_ONLY_BUILD:
                # every kernel volume up to 32 uses the masked implicit gemm
                algo = ConvAlgo.MaskImplicitGemm
            else:
                algo = ConvAlgo.Native
        if self.spss or self.focal:
            assert algo == ConvAlgo.Native, "implict gemm doesn't support spss for now"
            assert in_channels == out_channels, "input channels must equal to output channels in SPSS"
            assert not (self.spss and self.focal), "spss and focal can not be used concurrently"
        if groups > 1 or self.sprs:
            assert algo == ConvAlgo.Native, "channel wise groups don't support gemm for now"
        if kv > 32:
            assert algo == ConvAlgo.Native, "implicit gemm don't support kv > 32 for now"
        if CPU_ONLY_BUILD:
            assert algo == ConvAlgo.Native, "cpu only build only support native algorithm"
        self.algo = algo
        self.fp32_accum = fp32_accum
        if self.algo == ConvAlgo.Native:
            if FILTER_HWIO:
                # RSCK
                self.weight = Parameter(
                    torch.Tensor(*self.kernel_size, in_channels // groups, out_channels))
            else:
                # RSKC
                self.weight = Parameter(
                    torch.Tensor(*self.kernel_size, out_channels, in_channels // groups))
        else:
            # KRSC
            self.weight = Parameter(
                torch.Tensor(out_channels, *self.kernel_size, in_channels // groups))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def extra_repr(self):
        """Return the module summary used by ``nn.Module.__repr__``."""
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        # Compare element-wise: expand_nd's result is compared against a list
        # in __init__, and a list never equals a tuple, so the old
        # ``!= (0, ) * len(...)`` test was always true.
        if any(p != 0 for p in self.padding):
            s += ', padding={padding}'
        if any(d != 1 for d in self.dilation):
            s += ', dilation={dilation}'
        if any(p != 0 for p in self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.algo is not None:
            s += f', algo={self.algo}'
        return s.format(**self.__dict__)

    def _calculate_fan_in_and_fan_out(self):
        """Return ``(fan_in, fan_out)`` independent of the weight layout.

        Matches ``torch.nn.init`` for a conv weight of shape
        ``(out, in // groups, *kernel_size)``: fan_in counts only the input
        channels of one group, fan_out counts all output channels.
        """
        receptive_field_size = 1
        # math.prod is not always available, accumulate the product manually
        # we could use functools.reduce but that is not supported by TorchScript
        for s in self.kernel_size:
            receptive_field_size *= s
        fan_in = (self.in_channels // self.groups) * receptive_field_size
        fan_out = self.out_channels * receptive_field_size
        return fan_in, fan_out

    def _calculate_correct_fan(self, mode):
        """Return fan_in or fan_out according to ``mode``.

        Raises:
            ValueError: if ``mode`` is not 'fan_in' or 'fan_out'.
        """
        mode = mode.lower()
        valid_modes = ['fan_in', 'fan_out']
        if mode not in valid_modes:
            raise ValueError(
                "Mode {} not supported, please use one of {}".format(
                    mode, valid_modes))
        fan_in, fan_out = self._calculate_fan_in_and_fan_out()
        return fan_in if mode == 'fan_in' else fan_out

    def _custom_kaiming_uniform_(self,
                                 tensor,
                                 a=0,
                                 mode='fan_in',
                                 nonlinearity='leaky_relu'):
        r"""same as torch.init.kaiming_uniform_, with KRSC layout support
        """
        fan = self._calculate_correct_fan(mode)
        gain = calculate_gain(nonlinearity, a)
        std = gain / math.sqrt(fan)
        bound = math.sqrt(
            3.0) * std  # Calculate uniform bounds from standard deviation
        with torch.no_grad():
            return tensor.uniform_(-bound, bound)

    def reset_parameters(self):
        """Initialize weight (Kaiming-uniform, a=sqrt(5)) and bias like torch conv."""
        self._custom_kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = self._calculate_fan_in_and_fan_out()
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: SparseConvTensor, mask=None, ori_feat_num=-1, group_map=None):
        """Apply the sparse convolution to ``input``.

        Args:
            input: sparse tensor whose feature width must equal
                ``in_channels``.
            mask: required by the SPSS / SPRS variants. It is forwarded to
                ``ops.get_indice_pairs``; for SPSS, rows where ``mask`` is
                False keep their input features in the output.
            ori_feat_num, group_map: forwarded verbatim to
                ``ops.get_indice_pairs`` on the native path.

        Returns:
            A shadow copy of ``input`` carrying the output features, indices,
            spatial shape and the (possibly extended) ``indice_dict``.
        """
        assert isinstance(input, SparseConvTensor)
        assert input.features.shape[
            1] == self.in_channels, "channel size mismatch"
        features = input.features
        indices = input.indices
        spatial_shape = input.spatial_shape
        batch_size = input.batch_size
        if self.spss or self.sprs:
            # The mask drives index generation for these variants, so
            # validate it up front instead of failing deep inside the ops.
            assert mask is not None, "spss/sprs conv requires a mask"
            assert features.device == mask.device, "mask should be in the same device with features"
        if not (self.subm or self.spss or self.focal or self.sprs):
            if self.transposed:
                out_spatial_shape = ops.get_deconv_output_size(
                    spatial_shape, self.kernel_size, self.stride, self.padding,
                    self.dilation, self.output_padding)
            else:
                out_spatial_shape = ops.get_conv_output_size(
                    spatial_shape, self.kernel_size, self.stride, self.padding,
                    self.dilation)
        else:
            # submanifold-style variants keep the input spatial shape
            out_spatial_shape = spatial_shape
        out_tensor = input.shadow_copy()
        if input.benchmark:
            if self.name is None:
                raise ValueError(
                    "you need to assign name to spmodules before benchmark (spconv.utils.bench.assign_name_to_spmod)"
                )
            if self.name not in input.benchmark_record:
                input.benchmark_record[self.name] = {
                    "type": "SparseConvolution",
                    "indice_gen_time": [],
                    "time": [],
                    "num_points": [],
                    "num_out_points": [],
                    "params": {
                        "kernel_size": self.kernel_size,
                        "stride": self.stride,
                        "padding": self.padding,
                        "dilation": self.dilation,
                        "output_padding": self.output_padding,
                        "subm": self.subm,
                        "transposed": self.transposed,
                        "input_channels": self.in_channels,
                        "out_channels": self.out_channels,
                    }
                }
        if self.conv1x1:
            # A 1x1 conv is a dense matmul over the active rows.
            # NOTE(review): confirm these reshapes match the weight layouts
            # allocated in __init__ (RSCK / RSKC / KRSC); changing them would
            # alter how existing checkpoints are interpreted.
            if FILTER_HWIO:
                features = torch.mm(
                    input.features,
                    self.weight.view(self.out_channels, self.in_channels).T)
            else:
                features = torch.mm(
                    input.features,
                    self.weight.view(self.in_channels, self.out_channels))
            if self.bias is not None:
                features += self.bias
            out_tensor = out_tensor.replace_feature(features)
            # padding may change spatial shape of conv 1x1.
            out_tensor.spatial_shape = out_spatial_shape
            return out_tensor
        indice_dict = input.indice_dict.copy()
        algo = self.algo
        if self.indice_key is not None:
            datas = input.find_indice_pair(self.indice_key)
            if datas is not None:
                assert not self.spss and not self.focal and not self.sprs, "spss, sprs and focal can not reuse previous indice_key"
                msg = "due to limitation of pytorch, you must provide same algo to layers share same indice key."
                assert algo == datas.algo, msg
        profile_ctx = nullcontext()
        if input._timer is not None and self._sparse_unique_name:
            profile_ctx = input._timer.namespace(self._sparse_unique_name)
        with profile_ctx:
            if algo == ConvAlgo.Native:
                datas = input.find_indice_pair(self.indice_key)
                if datas is not None:
                    assert isinstance(datas, IndiceData)
                if self.inverse:
                    assert datas is not None and self.indice_key is not None
                    assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops."
                    outids = datas.indices
                    indice_pairs = datas.indice_pairs
                    indice_pair_num = datas.indice_pair_num
                    out_spatial_shape = datas.spatial_shape
                    assert datas.ksize == self.kernel_size, "inverse conv must have same kernel size as its couple conv"
                else:
                    if self.indice_key is not None and datas is not None:
                        outids = datas.out_indices
                        indice_pairs = datas.indice_pairs
                        indice_pair_num = datas.indice_pair_num
                        assert self.subm, "only support the reuse of indices for subm"
                        self._check_subm_reuse_valid(input, spatial_shape,
                                                     datas)
                    else:
                        if input.benchmark:
                            torch.cuda.synchronize()
                            t = time.time()
                        try:
                            outids, indice_pairs, indice_pair_num = ops.get_indice_pairs(
                                indices, batch_size, spatial_shape, algo,
                                self.kernel_size, self.kernel_size_ori,
                                self.stride, self.padding, self.dilation, self.output_padding, self.subm,
                                self.spss, self.sprs, self.focal, self.transposed, mask=mask, ori_feat_num=ori_feat_num, group_map=group_map)
                        except Exception as e:
                            msg = "[Exception|native_pair]"
                            msg += f"indices={indices.shape},bs={batch_size},ss={spatial_shape},"
                            msg += f"algo={algo},ksize={self.kernel_size},stride={self.stride},"
                            msg += f"padding={self.padding},dilation={self.dilation},subm={self.subm},"
                            msg += f"transpose={self.transposed}"
                            print(msg, file=sys.stderr)
                            spconv_save_debug_data(indices)
                            raise e
                        if input.benchmark:
                            torch.cuda.synchronize()
                            interval = time.time() - t
                            out_tensor.benchmark_record[
                                self.name]["indice_gen_time"].append(interval)
                        indice_data = IndiceData(outids,
                                                 indices,
                                                 indice_pairs,
                                                 indice_pair_num,
                                                 spatial_shape,
                                                 out_spatial_shape,
                                                 is_subm=self.subm,
                                                 algo=algo,
                                                 ksize=self.kernel_size,
                                                 stride=self.stride,
                                                 padding=self.padding,
                                                 dilation=self.dilation)
                        if self.indice_key is not None:
                            msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
                            assert self.indice_key not in indice_dict, msg
                            indice_dict[self.indice_key] = indice_data
                if input.benchmark:
                    torch.cuda.synchronize()
                    t = time.time()
                indice_pairs_calc = indice_pairs
                if indice_pairs.device != features.device:
                    indice_pairs_calc = indice_pairs.to(features.device)
                if self.subm:
                    conv_func = Fsp.indice_subm_conv_groups if self.groups > 1 else Fsp.indice_subm_conv
                    out_features = conv_func(
                        features, self.weight, indice_pairs_calc,
                        indice_pair_num, outids.shape[0], algo, self.groups,
                        self.spatial_groups, self.position_embedding, input._timer)
                elif self.spss or self.focal:
                    conv_func = Fsp.indice_conv_groups if self.groups > 1 else Fsp.indice_conv
                    out_features = conv_func(
                        features, self.weight,
                        indice_pairs_calc,
                        indice_pair_num,
                        outids.shape[0], algo,
                        self.groups, input._timer)
                elif self.inverse:
                    out_features = Fsp.indice_inverse_conv(
                        features, self.weight, indice_pairs_calc,
                        indice_pair_num, outids.shape[0], algo)
                else:
                    conv_func = Fsp.indice_conv_groups if self.groups > 1 else Fsp.indice_conv
                    out_features = conv_func(features, self.weight,
                                             indice_pairs_calc,
                                             indice_pair_num,
                                             outids.shape[0], algo,
                                             self.groups, input._timer)
            else:
                datas = input.find_indice_pair(self.indice_key)
                if datas is not None:
                    assert isinstance(datas, ImplicitGemmIndiceData)
                if self.inverse:
                    assert datas is not None and self.indice_key is not None
                    assert datas.is_subm is False, "inverse conv can only be used with standard conv and pool ops."
                    # an inverse conv runs its couple conv's pairs backwards,
                    # so the fwd/bwd roles are swapped
                    outids = datas.indices
                    pair_fwd = datas.pair_bwd
                    pair_bwd = datas.pair_fwd
                    pair_mask_fwd_splits = datas.pair_mask_bwd_splits
                    pair_mask_bwd_splits = datas.pair_mask_fwd_splits
                    mask_argsort_fwd_splits = datas.mask_argsort_bwd_splits
                    mask_argsort_bwd_splits = datas.mask_argsort_fwd_splits
                    masks = datas.masks
                    out_spatial_shape = datas.spatial_shape
                    assert datas.ksize == self.kernel_size, "inverse conv must have same kernel size as its couple conv"
                else:
                    if self.indice_key is not None and datas is not None:
                        outids = datas.out_indices
                        pair_fwd = datas.pair_fwd
                        pair_bwd = datas.pair_bwd
                        pair_mask_fwd_splits = datas.pair_mask_fwd_splits
                        pair_mask_bwd_splits = datas.pair_mask_bwd_splits
                        mask_argsort_fwd_splits = datas.mask_argsort_fwd_splits
                        mask_argsort_bwd_splits = datas.mask_argsort_bwd_splits
                        masks = datas.masks
                        assert self.subm, "only support reuse subm indices"
                        self._check_subm_reuse_valid(input, spatial_shape,
                                                     datas)
                    else:
                        with input._timer.namespace("gen_pairs"):
                            # we need to gen bwd indices for regular conv
                            # because it may be inversed.
                            try:
                                res = ops.get_indice_pairs_implicit_gemm(
                                    indices,
                                    batch_size,
                                    spatial_shape,
                                    algo,
                                    ksize=self.kernel_size,
                                    stride=self.stride,
                                    padding=self.padding,
                                    dilation=self.dilation,
                                    out_padding=self.output_padding,
                                    subm=self.subm,
                                    transpose=self.transposed,
                                    is_train=(not self.subm) or self.training,
                                    alloc=input.thrust_allocator,
                                    timer=input._timer)
                            except Exception as e:
                                msg = "[Exception|implicit_gemm_pair]"
                                msg += f"indices={indices.shape},bs={batch_size},ss={spatial_shape},"
                                msg += f"algo={algo},ksize={self.kernel_size},stride={self.stride},"
                                msg += f"padding={self.padding},dilation={self.dilation},subm={self.subm},"
                                msg += f"transpose={self.transposed}"
                                print(msg, file=sys.stderr)
                                spconv_save_debug_data(indices)
                                raise e
                        # res[1] (indices per location) is not needed here
                        outids = res[0]
                        pair_fwd = res[2]
                        pair_bwd = res[3]
                        pair_mask_fwd_splits = res[4]
                        pair_mask_bwd_splits = res[5]
                        mask_argsort_fwd_splits = res[6]
                        mask_argsort_bwd_splits = res[7]
                        masks = res[8]
                        if self.indice_key is not None:
                            indice_data = ImplicitGemmIndiceData(
                                outids,
                                indices,
                                pair_fwd,
                                pair_bwd,
                                pair_mask_fwd_splits=pair_mask_fwd_splits,
                                pair_mask_bwd_splits=pair_mask_bwd_splits,
                                mask_argsort_fwd_splits=mask_argsort_fwd_splits,
                                mask_argsort_bwd_splits=mask_argsort_bwd_splits,
                                masks=masks,
                                is_subm=self.subm,
                                spatial_shape=spatial_shape,
                                out_spatial_shape=out_spatial_shape,
                                algo=algo,
                                ksize=self.kernel_size,
                                stride=self.stride,
                                padding=self.padding,
                                dilation=self.dilation)
                            msg = f"your indice key {self.indice_key} already exists in this sparse tensor."
                            assert self.indice_key not in indice_dict, msg
                            indice_dict[self.indice_key] = indice_data
                if input.benchmark:
                    torch.cuda.synchronize()
                    t = time.time()
                num_activate_out = outids.shape[0]
                out_features = Fsp.implicit_gemm(
                    features, self.weight, pair_fwd, pair_bwd,
                    pair_mask_fwd_splits, pair_mask_bwd_splits,
                    mask_argsort_fwd_splits, mask_argsort_bwd_splits,
                    num_activate_out, masks, self.training, self.subm,
                    input._timer, self.fp32_accum)
        if self.bias is not None:
            out_features += self.bias
        if input.benchmark:
            torch.cuda.synchronize()
            interval = time.time() - t
            out_tensor.benchmark_record[self.name]["time"].append(interval)
            out_tensor.benchmark_record[self.name]["num_points"].append(
                features.shape[0])
            out_tensor.benchmark_record[self.name]["num_out_points"].append(
                out_features.shape[0])
        if self.spss:
            # rows not selected by the mask pass their input features through
            out_features[~mask] = input.features[~mask]
        out_tensor = out_tensor.replace_feature(out_features)
        out_tensor.indices = outids
        out_tensor.indice_dict = indice_dict
        out_tensor.spatial_shape = out_spatial_shape
        return out_tensor

    def _check_subm_reuse_valid(self, inp: SparseConvTensor,
                                spatial_shape: List[int],
                                datas: Union[ImplicitGemmIndiceData,
                                             IndiceData]):
        """Ensure cached submanifold indices are compatible with this layer.

        Raises:
            AssertionError: if the cached data is not submanifold.
            ValueError: on kernel size, dilation, spatial shape or point-count
                mismatch.
        """
        assert datas.is_subm, "only support reuse subm indices"
        if self.kernel_size != datas.ksize:
            raise ValueError(
                f"subm with same indice_key must have same kernel"
                f" size, expect {datas.ksize}, this layer {self.kernel_size}")
        if self.dilation != datas.dilation:
            raise ValueError(
                f"subm with same indice_key must have same dilation"
                f", expect {datas.dilation}, this layer {self.dilation}")
        if inp.spatial_shape != datas.spatial_shape:
            raise ValueError(
                f"subm with same indice_key must have same spatial structure"
                f", expect {datas.spatial_shape}, input {spatial_shape}")
        if inp.indices.shape[0] != datas.indices.shape[0]:
            raise ValueError(
                f"subm with same indice_key must have same num of indices"
                f", expect {datas.indices.shape[0]}, input {inp.indices.shape[0]}"
            )
class SparseConv1d(SparseConvolution):
    """Regular (non-submanifold) 1-d sparse convolution.

    Fixes ``ndim=1`` and forwards every other argument, by keyword, to
    ``SparseConvolution``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(
            ndim=1,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            indice_key=indice_key,
            algo=algo,
            fp32_accum=fp32_accum,
            name=name,
        )
class SparseConv2d(SparseConvolution):
    """Regular (non-submanifold) 2-d sparse convolution.

    Fixes ``ndim=2`` and forwards every other argument, by keyword, to
    ``SparseConvolution``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(
            ndim=2,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            indice_key=indice_key,
            algo=algo,
            fp32_accum=fp32_accum,
            name=name,
        )
class SparseConv3d(SparseConvolution):
    """Regular (non-submanifold) 3-d sparse convolution.

    Fixes ``ndim=3`` and forwards every other argument, by keyword, to
    ``SparseConvolution``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(
            ndim=3,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            indice_key=indice_key,
            algo=algo,
            fp32_accum=fp32_accum,
            name=name,
        )
class SparseConv4d(SparseConvolution):
    """Regular (non-submanifold) 4-d sparse convolution.

    Fixes ``ndim=4`` and forwards every other argument, by keyword, to
    ``SparseConvolution``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(
            ndim=4,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            indice_key=indice_key,
            algo=algo,
            fp32_accum=fp32_accum,
            name=name,
        )
class SparseConvTranspose1d(SparseConvolution):
    """Transposed 1-d sparse convolution (``transposed=True``, ``ndim=1``)."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(
            ndim=1,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            transposed=True,
            indice_key=indice_key,
            algo=algo,
            fp32_accum=fp32_accum,
            name=name,
        )
class SparseConvTranspose2d(SparseConvolution):
    """Transposed 2-d sparse convolution (``transposed=True``, ``ndim=2``)."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(
            ndim=2,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            transposed=True,
            indice_key=indice_key,
            algo=algo,
            fp32_accum=fp32_accum,
            name=name,
        )
class SparseConvTranspose3d(SparseConvolution):
    """Transposed 3-d sparse convolution (``transposed=True``, ``ndim=3``)."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(
            ndim=3,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            transposed=True,
            indice_key=indice_key,
            algo=algo,
            fp32_accum=fp32_accum,
            name=name,
        )
class SparseConvTranspose4d(SparseConvolution):
    """Transposed 4-d sparse convolution (``transposed=True``, ``ndim=4``)."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(
            ndim=4,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            transposed=True,
            indice_key=indice_key,
            algo=algo,
            fp32_accum=fp32_accum,
            name=name,
        )
class SparseInverseConv1d(SparseConvolution):
    """1-d inverse sparse convolution.

    Reuses, in reverse, the index pairs that an earlier regular conv/pool
    cached under ``indice_key`` (``inverse=True``, ``ndim=1``).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 indice_key,
                 bias=True,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(
            ndim=1,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            bias=bias,
            inverse=True,
            indice_key=indice_key,
            algo=algo,
            fp32_accum=fp32_accum,
            name=name,
        )
class SparseInverseConv2d(SparseConvolution):
    """2-d inverse sparse convolution.

    Reuses, in reverse, the index pairs that an earlier regular conv/pool
    cached under ``indice_key`` (``inverse=True``, ``ndim=2``).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 indice_key,
                 bias=True,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(
            ndim=2,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            bias=bias,
            inverse=True,
            indice_key=indice_key,
            algo=algo,
            fp32_accum=fp32_accum,
            name=name,
        )
class SparseInverseConv3d(SparseConvolution):
    """3-d inverse sparse convolution.

    Reuses, in reverse, the index pairs that an earlier regular conv/pool
    cached under ``indice_key`` (``inverse=True``, ``ndim=3``).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 indice_key,
                 bias=True,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(
            ndim=3,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            bias=bias,
            inverse=True,
            indice_key=indice_key,
            algo=algo,
            fp32_accum=fp32_accum,
            name=name,
        )
class SparseInverseConv4d(SparseConvolution):
    """4-d inverse sparse convolution.

    Reuses, in reverse, the index pairs that an earlier regular conv/pool
    cached under ``indice_key`` (``inverse=True``, ``ndim=4``).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 indice_key,
                 bias=True,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(
            ndim=4,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            bias=bias,
            inverse=True,
            indice_key=indice_key,
            algo=algo,
            fp32_accum=fp32_accum,
            name=name,
        )
class SubMConv1d(SparseConvolution):
    """Submanifold 1-d sparse convolution (``subm=True``, ``ndim=1``).

    The output keeps the input spatial shape.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(
            ndim=1,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            subm=True,
            indice_key=indice_key,
            algo=algo,
            fp32_accum=fp32_accum,
            name=name,
        )
class SubMConv2d(SparseConvolution):
    """Submanifold 2-d sparse convolution (``subm=True``, ``ndim=2``).

    The output keeps the input spatial shape.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(
            ndim=2,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            subm=True,
            indice_key=indice_key,
            algo=algo,
            fp32_accum=fp32_accum,
            name=name,
        )
class SubMConv3d(SparseConvolution):
    """Submanifold 3-d sparse convolution (``subm=True``, ``ndim=3``).

    The output keeps the input spatial shape.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(
            ndim=3,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            subm=True,
            indice_key=indice_key,
            algo=algo,
            fp32_accum=fp32_accum,
            name=name,
        )
class SubMConv4d(SparseConvolution):
    """Submanifold 4-d sparse convolution (``subm=True``, ``ndim=4``).

    The output keeps the input spatial shape.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 indice_key=None,
                 algo: Optional[ConvAlgo] = None,
                 fp32_accum: Optional[bool] = None,
                 name=None):
        super().__init__(
            ndim=4,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            subm=True,
            indice_key=indice_key,
            algo=algo,
            fp32_accum=fp32_accum,
            name=name,
        )
class SPSSConv3d(SparseConvolution):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
indice_key=None,
algo: Optional[ConvAlgo] = None,
fp32_accum: Optional[bool] = None,
name=None):
super(SPSSConv3d, self).__init__(3,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,