-
Notifications
You must be signed in to change notification settings - Fork 130
/
index_add_cuda_pytorch_impl.cu
582 lines (504 loc) · 21.2 KB
/
index_add_cuda_pytorch_impl.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <cstdint>
#include "oneflow/core/common/bfloat16.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/shape.h"
#include "oneflow/core/common/shape_vec.h"
#include "oneflow/core/cuda/atomic.cuh"
#include "oneflow/core/common/util.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/user_op_hob.h"
#include "oneflow/core/kernel/new_kernel_util.h"
#include "oneflow/core/kernel/kernel_util.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
namespace oneflow {
namespace {
/*
[collapse dims] Updates sizes, and strides to reflect a "collapse" of
the info, possibly excluding the optional excludeDim. A "collapsed" version
of the info is the fewest dims that order the tensor's elements in the same
way as the original info. If excludeDim is specified, the collapse is the
fewest dims that order the tensor's elements as the original and preserve the
excluded dimension, unless the tensor collapses to a point.
This function returns a pair of values.
1) The (new) index of the preserved dimension if excludeDim is
specified. 0 if the tensor is collapsed to a point. -1
otherwise.
2) The new number of dimensions.
*/
// Rewrites `sizes` and `strides` in place into their collapsed form (see the
// [collapse dims] note above). `dims` is passed by value, so the new dimension
// count reaches the caller only through the second element of the returned
// pair; the first element is the remapped position of `excludeDim`.
template <typename T>
inline std::pair<int64_t, int64_t> collapse_dims(
T* sizes,
T* strides,
int64_t dims,
const int excludeDim = -1) {
CHECK_EQ(
excludeDim >= -1 && excludeDim < dims, true) <<
"expected excluded dim between -1 and dims - 1";
// First pass stops at the excluded dim (or runs over everything if none).
int64_t stopDim = (excludeDim == -1) ? dims : excludeDim;
// newIndex is the write cursor (last slot written), oldIndex the read cursor.
int64_t newIndex = -1;
int64_t oldIndex = 0;
int64_t remappedExcludedDim = -1;
while (oldIndex < dims) {
// Finds a dimension to collapse into
for (; oldIndex < stopDim; ++oldIndex) {
// Size-1 dims contribute nothing to element ordering and are dropped.
if (sizes[oldIndex] == 1) {
continue;
}
++newIndex;
sizes[newIndex] = sizes[oldIndex];
strides[newIndex] = strides[oldIndex];
++oldIndex;
break;
}
// Collapses dims
for (; oldIndex < stopDim; ++oldIndex) {
if (sizes[oldIndex] == 1) {
continue;
}
// The current output dim and this input dim can be merged when stepping
// the outer one by 1 equals walking the whole inner one.
if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) {
sizes[newIndex] *= sizes[oldIndex];
strides[newIndex] = strides[oldIndex];
} else {
++newIndex;
sizes[newIndex] = sizes[oldIndex];
strides[newIndex] = strides[oldIndex];
}
}
// Handles excludeDim being set (oldIndex == excludeDim)
if (oldIndex != dims) {
// Preserves excluded dimension
++newIndex;
sizes[newIndex] = sizes[oldIndex];
strides[newIndex] = strides[oldIndex];
remappedExcludedDim = newIndex;
// Restarts iteration after excludeDim
++oldIndex;
stopDim = dims;
}
}
// Handles special case of all dims size 1
if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) {
// Only the local copy of dims changes here; the caller sees the pair below.
dims = 1;
sizes[0] = 1;
strides[0] = 1;
return std::pair<int64_t, int64_t>(0, 1);
}
dims = newIndex + 1;
return std::pair<int64_t, int64_t>(remappedExcludedDim, dims);
}
// Returns true when the first `num_dims` axes of `t` are laid out densely in
// row-major order: the innermost stride is 1 and every other stride equals the
// product of all sizes to its right.
//
// Strides are kept as 64-bit values; narrowing them to int32_t (as was done
// previously) makes the comparison wrong once a stride exceeds INT32_MAX.
bool IsContiguous(size_t num_dims, const ShapeView& t, const Stride& stride) {
  DimVector t_shape_dim;
  t.ToDimVector(&t_shape_dim);
  const std::vector<int64_t> t_stride(stride.begin(), stride.end());
  // Stride a dense layout would have at the axis currently being inspected.
  int64_t expected_stride = 1;
  for (int i = static_cast<int>(num_dims) - 1; i >= 0; i--) {
    if (t_stride[i] != expected_stride) {
      return false;
    }
    expected_stride *= t_shape_dim[i];
  }
  return true;
}
// Plain value type describing a tensor's layout (per-axis sizes and strides).
// It is passed by value as a kernel argument by the index kernels in this file.
template <typename IndexType>
struct TensorInfo {
  // Leaves the tensor with zero dimensions.
  TensorInfo();

  // Copies the first `dim` entries of `sz` / `st`.
  TensorInfo(int dim,
             IndexType sz[SHAPE_MAX_AXIS_SIZE],
             IndexType st[SHAPE_MAX_AXIS_SIZE]);

  // Forces the size of axis `dim` to 1, as if it were a reduction axis, so
  // that offsets of the remaining slice can be computed.
  void reduceDim(int dim);

  // See note on [collapse dims].
  int collapseDims(const int excludeDim = -1);

  // A contiguous tensor of more than one dimension collapses down to a single
  // dimension, so after collapsing, "one axis with unit stride" is the test.
  OF_DEVICE_FUNCTION bool isContiguous() const {
    if (dims != 1) {
      return false;
    }
    return strides[0] == 1;
  }

  IndexType sizes[SHAPE_MAX_AXIS_SIZE];
  IndexType strides[SHAPE_MAX_AXIS_SIZE];
  int dims;
};
// Default construction: zero dimensions; `sizes` and `strides` are left
// uninitialised.
template <typename IndexType>
TensorInfo<IndexType>::TensorInfo() : dims(0) {}
// Builds a TensorInfo from the first `dim` entries of `sz` (sizes) and `st`
// (strides). Aborts via CHECK when `dim` is not strictly below
// SHAPE_MAX_AXIS_SIZE.
template <typename IndexType>
TensorInfo<IndexType>::TensorInfo(int dim,
                                  IndexType sz[SHAPE_MAX_AXIS_SIZE],
                                  IndexType st[SHAPE_MAX_AXIS_SIZE]) {
  dims = dim;
  // Report the limit that is actually enforced instead of a hard-coded number.
  CHECK_EQ(dims < SHAPE_MAX_AXIS_SIZE, true)
      << "CUDA Tensors cannot have " << SHAPE_MAX_AXIS_SIZE << " or more dimensions";
  for (int i = 0; i < dim; ++i) {
    sizes[i] = sz[i];
    strides[i] = st[i];
  }
}
// Sets the size of axis `dim` to 1 (the stride is left untouched).
// Aborts via CHECK when `dim` is outside [0, dims).
template <typename IndexType>
void TensorInfo<IndexType>::reduceDim(int dim) {
  const bool in_range = (dim < dims && dim >= 0);
  CHECK_EQ(in_range, true) << "expected dim between 0 and dims - 1";
  sizes[dim] = 1;
}
// Collapses this tensor's sizes/strides in place via collapse_dims, stores the
// new dimension count in `dims`, and returns the first element of the pair
// collapse_dims produced (the remapped position of `excludeDim`).
template <typename IndexType>
int TensorInfo<IndexType>::collapseDims(const int excludeDim) {
  const std::pair<int64_t, int64_t> collapsed =
      collapse_dims(sizes, strides, dims, excludeDim);
  dims = collapsed.second;
  return collapsed.first;
}
// Translates a linear (row-major) element index into a storage offset using
// the sizes and strides in `info`. Specialised on the compile-time rank `Dims`
// to reduce nvcc compilation time.
//
// For Dims <= 1 the loop body never executes and the result is simply
// linearId * info.strides[0].
template <typename IndexType, int Dims>
struct IndexToOffset {
  static OF_DEVICE_FUNCTION IndexType get(
      IndexType linearId,
      const TensorInfo<IndexType>& info) {
    IndexType offset = 0;
    // Uses static dims. Peel coordinates off from the innermost axis outwards:
    // the remainder is the coordinate on axis i, the quotient carries on.
    for (int i = Dims - 1; i > 0; --i) {
      IndexType curDimIndex = linearId % info.sizes[i];
      IndexType curDimOffset = curDimIndex * info.strides[i];
      offset += curDimOffset;
      linearId /= info.sizes[i];
    }
    return offset + linearId * info.strides[0];
  }
};
// Specialisation for Dims == -1: same translation from a linear (row-major)
// element index to a storage offset, but the rank is read from `info.dims` at
// run time instead of being a compile-time constant.
template <typename IndexType>
struct IndexToOffset<IndexType, -1> {
  static OF_DEVICE_FUNCTION IndexType get(
      IndexType linearId,
      const TensorInfo<IndexType>& info) {
    IndexType offset = 0;
    // The remainder is the coordinate on axis i; the quotient is the linear
    // index over the remaining outer axes.
    for (int i = info.dims - 1; i > 0; --i) {
      IndexType curDimIndex = linearId % info.sizes[i];
      IndexType curDimOffset = curDimIndex * info.strides[i];
      offset += curDimOffset;
      linearId /= info.sizes[i];
    }
    return offset + linearId * info.strides[0];
  }
};
// Builds a TensorInfo<IndexType> from a shape view and its strides.
//
// Strides are read as 64-bit values and converted once to IndexType, so a
// 64-bit IndexType keeps the full stride (they used to be squeezed through
// int32_t first). The rank is validated before the fixed-size staging arrays
// are written, so an oversized shape aborts instead of overrunning the stack.
template <typename IndexType>
TensorInfo<IndexType>
getTensorInfo(const ShapeView &t, const Stride& stride) {
  IndexType sz[SHAPE_MAX_AXIS_SIZE];
  IndexType st[SHAPE_MAX_AXIS_SIZE];
  DimVector t_shape_dim;
  t.ToDimVector(&t_shape_dim);
  const std::vector<int64_t> t_stride(stride.begin(), stride.end());
  const int dims = static_cast<int>(t_shape_dim.size());
  CHECK_EQ(dims < SHAPE_MAX_AXIS_SIZE, true)
      << "CUDA Tensors cannot have " << SHAPE_MAX_AXIS_SIZE << " or more dimensions";
  for (int i = 0; i < dims; ++i) {
    sz[i] = t_shape_dim[i];
    st[i] = t_stride[i];
  }
  return TensorInfo<IndexType>(dims, sz, st);
}
// Integer division rounded up: (a + b - 1) / b.
// Only participates in overload resolution for integral T.
template <typename T, typename = std::enable_if_t<std::is_integral<T>::value>>
OF_DEVICE_FUNCTION T ceil_div(T a, T b) {
  // `auto` keeps the promoted type of the sum, exactly as in a single expression.
  const auto biased = a + b - 1;
  return biased / b;
}
// Returns true when every element of the tensor described by (t, stride) can
// be addressed with an offset strictly below `max_elem`, i.e. when index
// arithmetic in a 32-bit type cannot overflow.
//
// Two things are checked: the element count, and the storage offset of the
// last element (which can exceed the element count for strided layouts).
// Strides are kept 64-bit here; truncating them to int32_t before the test
// would hide exactly the overflow this function is meant to detect.
bool canUse32BitIndexMath(const ShapeView& t, const Stride& stride, int32_t max_elem=std::numeric_limits<int32_t>::max()) {
  auto elements = t.elem_cnt();
  if (elements >= max_elem) {
    return false;
  }
  if (elements == 0) {
    return max_elem > 0;
  }
  size_t offset = 0;
  auto linearId = elements - 1;
  DimVector t_shape_dim;
  t.ToDimVector(&t_shape_dim);
  const std::vector<int64_t> t_stride(stride.begin(), stride.end());
  // NOTE: Assumes all strides are positive, which is true for now
  for (int i = static_cast<int>(t_shape_dim.size()) - 1; i >= 0; --i) {
    auto curDimIndex = linearId % t_shape_dim[i];
    auto curDimOffset = curDimIndex * t_stride[i];
    offset += curDimOffset;
    linearId /= t_shape_dim[i];
  }
  // max_elem is positive here: elements >= 1 passed the `elements >= max_elem`
  // test above, so the conversion to size_t is value-preserving.
  return offset < static_cast<size_t>(max_elem);
}
// Stateless functor used as the reduction op of the index kernels: it forwards
// its arguments, dereferencing `src_data`, to cuda::atomic::FastAdd.
struct ReduceAdd {
  template <typename T>
  constexpr __device__ void operator()(T* self_data_start,
                                       int32_t index,
                                       int32_t numel,
                                       const T* src_data) const {
    cuda::atomic::FastAdd(self_data_start, index, numel, *src_data);
  }
};
// Single instance handed to the kernel launches.
static ReduceAdd reduce_add;
// We prefer this kernel to avoid reloading index points if the number
// of indices is a small number.
// This kernel in fact works for all choices of problem size, but if
// the number of indices chosen is large, then the
// indexFuncLargeIndex kernel is a better choice to increase
// parallelism.
// Parameters:
//   dst, src, indices  layout descriptors; the caller in this file passes dst
//                      and src with the indexed axis reduced to size 1
//   indices_data       index values, read through `indices`
//   src_data/dst_data  element storage of source and destination
//   dstAddDim/srcAddDim  position of the indexed axis inside dst / src
//   innerSize          number of elements in one slice (bound of the inner loop)
//   dstAddDimSize      extent of dst along the indexed axis (upper bound for
//                      the assert on each index value)
//   dstNumel           forwarded unchanged to `op`
//   op                 called as op(dst_data, dstOffset, dstNumel, &val)
//   alpha              factor applied to every source element before `op`
template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim,
typename func_t>
__global__ void indexFuncSmallIndex(TensorInfo<IndexType> dst,
TensorInfo<IndexType> src,
TensorInfo<IndexType> indices,
const IndexType* indices_data,
const T* src_data,
T* dst_data,
int32_t dstAddDim,
int32_t srcAddDim,
IndexType innerSize,
int32_t dstAddDimSize,
int32_t dstNumel,
const func_t& op,
T alpha) {
// In order to avoid reloading the index that we are copying, load
// it once to handle all of the points that are being selected, so
// it can be reused as much as possible. This kernel is chosen when
// this is a good choice (small number of chosen indices), since
// re-accessing indices in addition to src elements can be slow.
for (IndexType srcIndex = 0; srcIndex < indices.sizes[0]; ++srcIndex) {
// The index value is used as-is (zero-based); only its upper bound is
// asserted.
IndexType dstIndex =
indices_data[IndexToOffset<IndexType, IdxDim>::get(srcIndex, indices)];
assert(dstIndex < dstAddDimSize);
// We stride over the output ignoring the indexed dimension
// (innerSize), whose offset calculation is handled differently
for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
linearIndex < innerSize;
linearIndex += gridDim.x * blockDim.x) {
IndexType dstOffset =
IndexToOffset<IndexType, DstDim>::get(linearIndex, dst);
// Add the contribution of the indexed axis separately.
dstOffset += dstIndex * dst.strides[dstAddDim];
IndexType srcOffset =
IndexToOffset<IndexType, SrcDim>::get(linearIndex, src);
srcOffset += srcIndex * src.strides[srcAddDim];
T val = src_data[srcOffset] * alpha;
op(dst_data, dstOffset, dstNumel, &val);
}
}
}
// We prefer this kernel to balance parallelism across index points,
// if there are a large number of indices.
// This kernel in fact works for all choices of problem size, but if
// the number of indices chosen is small, then the
// indexFuncSmallIndex kernel is a better choice to reduce memory
// accesses.
//
// Each thread walks the flattened range [0, totalSize) in a grid-stride loop
// and splits every linear index into (srcIndex, elementInSlice):
//   IndexIsMajor == true   linearIndex = srcIndex * innerSize + elementInSlice
//   IndexIsMajor == false  linearIndex = elementInSlice * innerSize + srcIndex
// so `innerSize` is the extent of whichever coordinate varies fastest.
//
// Parameters:
//   dst, src, indices    layout descriptors for destination, source, indices
//   indices_data         index values, read through `indices`
//   src_data/dst_data    element storage of source and destination
//   dstAddDim/srcAddDim  position of the indexed axis inside dst / src
//   totalSize            number of linear indices to process
//   innerSize            divisor used to split a linear index (see above)
//   dstAddDimSize        extent of dst along the indexed axis (assert bound)
//   dstNumel             forwarded unchanged to `op`
//   op                   called as op(dst_data, dstOffset, dstNumel, &val)
//   alpha                factor applied to every source element before `op`
template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim,
          bool IndexIsMajor, typename func_t>
__global__ void indexFuncLargeIndex(TensorInfo<IndexType> dst,
                                    TensorInfo<IndexType> src,
                                    TensorInfo<IndexType> indices,
                                    const IndexType* indices_data,
                                    const T* src_data,
                                    T* dst_data,
                                    int32_t dstAddDim,
                                    int32_t srcAddDim,
                                    IndexType totalSize,
                                    IndexType innerSize,
                                    int32_t dstAddDimSize,
                                    int32_t dstNumel,
                                    const func_t& op,
                                    T alpha) {
  // We stride over the output including the indexed dimension
  // (totalSize), and calculate the destination index point based on that
  for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
       linearIndex < totalSize;
       linearIndex += gridDim.x * blockDim.x) {
    IndexType srcIndex, elementInSlice;
    // Quotient and remainder by innerSize give the two coordinates; the
    // remainder (not `linearIndex - quotient`) is the fast-varying one.
    if (IndexIsMajor) {
      srcIndex = linearIndex / innerSize;
      elementInSlice = linearIndex % innerSize;
    } else {
      elementInSlice = linearIndex / innerSize;
      srcIndex = linearIndex % innerSize;
    }
    // The index value is used as-is (zero-based); only its upper bound is
    // asserted.
    IndexType dstIndex =
        indices_data[IndexToOffset<IndexType, IdxDim>::get(srcIndex, indices)];
    assert(dstIndex < dstAddDimSize);
    IndexType dstOffset =
        IndexToOffset<IndexType, DstDim>::get(elementInSlice, dst);
    dstOffset += dstIndex * dst.strides[dstAddDim];
    IndexType srcOffset =
        IndexToOffset<IndexType, SrcDim>::get(elementInSlice, src);
    srcOffset += srcIndex * src.strides[srcAddDim];
    T val = src_data[srcOffset] * alpha;
    op(dst_data, dstOffset, dstNumel, &val);
  }
}
// Compare the stride between adjacent slices (sliceStride) with strides in the
// other dimensions (i.e., strides *inside* each slice).
//
// - Returns true if some dimension inside the slice has lower stride than
// sliceStride. The simplest example is a 2-D contiguous tensor with sliceDim
// == 0 (that is, each slice is a row).
//
// In this case, we choose the CUDA kernel that processes the data in
// "index-major order". For example, if thread count equals slice size, then
// all threads process slice #0 in lockstep, and then slice #1, and so on.
//
// - Otherwise (i.e., sliceStride has the lowest value), this function returns
// false. The simplest example is a 2-D contiguous tensor with sliceDim == 1
// (each slice is a column).
//
// In this case, we choose the CUDA kernel that processes the data in
// "elementInSlice-major order". For example, each thread can process element
// #0 of every slice, and then element #1 of every slice, and so on.
// Returns true as soon as some axis other than `sliceDim` has more than one
// element and a smaller stride than the slice axis; false if no such axis
// exists.
template <typename IndexT>
bool indexShouldBeMajor(TensorInfo<IndexT> &info,
                        int sliceDim)
{
  // Distance in storage between the same element of two neighbouring slices.
  unsigned int sliceStride = info.strides[sliceDim];

  size_t axis = 0;
  while (axis < info.dims) {
    if (axis != sliceDim && info.sizes[axis] > 1 && info.strides[axis] < sliceStride) {
      return true;
    }
    ++axis;
  }
  return false;
}
}; // namespace
// CUDA kernel for the "index_add" user op.
//
// Compute() first copies `input` into `output`, then launches one of the index
// kernels so that, for every position i of `index`, the slice of `source` at
// position i along `dim` is scaled by `alpha` and atomically added to the
// slice of `output` at position index[i] along `dim`.
//
// T is the element type, IndexT the element type of `index`; IndexT is also
// used as the offset type inside the kernels.
template<typename T, typename IndexT>
class IndexAddGpuKernel final : public user_op::OpKernel {
 public:
  IndexAddGpuKernel() = default;
  ~IndexAddGpuKernel() = default;

 private:
  using user_op::OpKernel::Compute;
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* self = ctx->Tensor4ArgNameAndIndex("input", 0);
    const user_op::Tensor* index = ctx->Tensor4ArgNameAndIndex("index", 0);
    const user_op::Tensor* source = ctx->Tensor4ArgNameAndIndex("source", 0);
    user_op::Tensor* output = ctx->Tensor4ArgNameAndIndex("output", 0);
    const int32_t dim = ctx->Attr<int32_t>("dim");
    const float alpha = ctx->Attr<float>("alpha");
    const ShapeView& self_shape = self->shape_view();
    const ShapeView& source_shape = source->shape_view();
    const ShapeView& index_shape = index->shape_view();
    DimVector self_shape_dim, index_shape_dim;
    self_shape.ToDimVector(&self_shape_dim);
    index_shape.ToDimVector(&index_shape_dim);
    // `dim` indexes self_shape_dim below; reject anything out of range instead
    // of reading past the shape.
    CHECK_EQ(dim >= 0 && dim < static_cast<int32_t>(self_shape_dim.size()), true)
        << "index_add: dim is out of range for input";
    const Stride& self_stride = self->stride();
    const Stride& index_stride = index->stride();
    const Stride& source_stride = source->stride();

    // output starts as a copy of input; the kernels below accumulate into it.
    Memcpy<DeviceType::kCUDA>(
        ctx->stream(), output->mut_dptr<void>(), self->dptr<void>(),
        self->shape_view().elem_cnt() * GetSizeOfDataType(self->data_type()));

    // Number of elements in one slice of self, i.e. all axes except `dim`.
    // NOTE(review): the size bookkeeping here is 32-bit, so the fallback branch
    // taken when 32-bit index math is not possible still cannot address such
    // tensors correctly - confirm whether that path is reachable in practice.
    int32_t sliceSize = 1;
    for (int i = 0; i < self_shape_dim.size(); i++) {
      if (i != dim) {
        sliceSize *= self_shape_dim[i];
      }
    }
    const int32_t sourceTotalSize = source_shape.elem_cnt();
    const int32_t selfAddDimSize = self_shape_dim[dim];
    const int32_t numIndex = index_shape.elem_cnt();
    const int32_t selfNumel = self_shape.elem_cnt();
    const cudaStream_t stream = ctx->stream()->As<ep::CudaStream>()->cuda_stream();

// Launch helpers. The last line of each definition must NOT end in a
// backslash, otherwise the following source line becomes part of the macro.
#define SMALL_INDEX(TENSOR_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM)                                \
  indexFuncSmallIndex<TENSOR_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM>                              \
      <<<smallIndexGrid, smallIndexBlock, 0, stream>>>(                                              \
          selfInfo, sourceInfo, indexInfo, index->dptr<TYPE>(), source->dptr<T>(), output->mut_dptr<T>(), \
          selfAddDim, sourceAddDim, sliceSize, selfAddDimSize,                                       \
          selfNumel, reduce_add, alpha_value);

#define LARGE_INDEX(TENSOR_TYPE, TYPE,                                                               \
                    SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR)                                     \
  indexFuncLargeIndex<TENSOR_TYPE, TYPE,                                                             \
                      SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR>                                   \
      <<<largeIndexGrid, largeIndexBlock, 0, stream>>>(                                              \
          selfInfo, sourceInfo, indexInfo, index->dptr<TYPE>(), source->dptr<T>(), output->mut_dptr<T>(), \
          selfAddDim, sourceAddDim, sourceTotalSize,                                                 \
          (IDX_IS_MAJOR) ? sliceSize : numIndex,                                                     \
          selfAddDimSize, selfNumel, reduce_add, alpha_value);

    const bool indContig = IsContiguous(index_shape_dim.size(), index_shape, index_stride);
    const int mpc = static_cast<int>(
        ctx->stream()->As<ep::CudaStream>()->device_properties().multiProcessorCount);
    // At most 128 threads per block and 8 blocks per multiprocessor; the
    // kernels use grid-stride loops, so any grid size is correct.
    const dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (int32_t)128), (int32_t)(mpc * 8)));
    const dim3 smallIndexBlock(std::min(sliceSize, (int32_t)128));
    const dim3 largeIndexGrid(std::min(ceil_div(sourceTotalSize, (int32_t)128), (int32_t)(mpc * 8)));
    const dim3 largeIndexBlock(std::min(sourceTotalSize, (int32_t)128));
    const T alpha_value = static_cast<T>(alpha);

    // collapseDims() may merge or drop axes, so the indexed axis can move.
    // Everything after the collapse must use the remapped axis it returns,
    // including reduceDim(); using the original `dim` there trips the range
    // CHECK or shrinks the wrong axis.
    TensorInfo<IndexT> selfInfo = getTensorInfo<IndexT>(self_shape, self_stride);
    const int selfAddDim = selfInfo.collapseDims(dim);
    selfInfo.reduceDim(selfAddDim);
    TensorInfo<IndexT> sourceInfo = getTensorInfo<IndexT>(source_shape, source_stride);
    const int sourceAddDim = sourceInfo.collapseDims(dim);
    sourceInfo.reduceDim(sourceAddDim);
    TensorInfo<IndexT> indexInfo = getTensorInfo<IndexT>(index_shape, index_stride);
    indexInfo.collapseDims();

    if (canUse32BitIndexMath(self_shape, self_stride)
        && canUse32BitIndexMath(source_shape, source_stride)
        && canUse32BitIndexMath(index_shape, index_stride)) {
      if (numIndex <= 16) {
        // Few indices: loop over them inside each thread.
        if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) {
          SMALL_INDEX(T, IndexT, 1, 1, -2);
        } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) {
          SMALL_INDEX(T, IndexT, 2, 2, -2);
        } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) {
          SMALL_INDEX(T, IndexT, 3, 3, -2);
        } else {
          SMALL_INDEX(T, IndexT, -1, -1, -1);
        }
      } else {
        // Many indices: parallelise over indices and slice elements together.
        const bool indexIsMajor = indexShouldBeMajor<IndexT>(selfInfo, selfAddDim);
        if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) {
          LARGE_INDEX(T, IndexT, 1, 1, -2, true);
        } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) {
          if (indexIsMajor) {
            LARGE_INDEX(T, IndexT, 2, 2, -2, true);
          } else {
            LARGE_INDEX(T, IndexT, 2, 2, -2, false);
          }
        } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) {
          if (indexIsMajor) {
            LARGE_INDEX(T, IndexT, 3, 3, -2, true);
          } else {
            LARGE_INDEX(T, IndexT, 3, 3, -2, false);
          }
        } else {
          LARGE_INDEX(T, IndexT, -1, -1, -1, true);
        }
      }
    } else {
      LARGE_INDEX(T, IndexT, -1, -1, -1, true);
    }

#undef SMALL_INDEX
#undef LARGE_INDEX
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
// Registers IndexAddGpuKernel<dtype, index_dtype> for the "index_add" op.
// The kernel is selected when the device is CUDA, "output" has data type
// `dtype`, and "index" has data type `index_dtype`.
#define REGISTER_INDEX_ADD_CUDA_KERNEL(dtype, index_dtype) \
REGISTER_USER_KERNEL("index_add") \
.SetCreateFn<IndexAddGpuKernel<dtype, index_dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("output", 0) == GetDataType<dtype>::value) \
&& (user_op::HobDataType("index", 0) == GetDataType<index_dtype>::value));
// One registration per supported (element type, index type) pair.
REGISTER_INDEX_ADD_CUDA_KERNEL(float, int32_t)
REGISTER_INDEX_ADD_CUDA_KERNEL(float, int64_t)
REGISTER_INDEX_ADD_CUDA_KERNEL(half, int32_t)
REGISTER_INDEX_ADD_CUDA_KERNEL(half, int64_t)
REGISTER_INDEX_ADD_CUDA_KERNEL(double, int32_t)
REGISTER_INDEX_ADD_CUDA_KERNEL(double, int64_t)
} // namespace oneflow