forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ForeachTernaryOp.cu
118 lines (102 loc) · 4.49 KB
/
ForeachTernaryOp.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/Lerp.h>
#include <ATen/native/cuda/ForeachFunctors.cuh>
#include <ATen/native/cuda/MultiTensorApply.cuh>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_foreach_lerp_native.h>
#include <ATen/ops/empty_like_native.h>
#endif

#include <utility>
#include <vector>
namespace at::native {
// Device-side callable handed to the foreach ternary functors in this file.
// It simply forwards its three operands, in order, to `lerp` (declared in
// ATen/native/Lerp.h) and returns whatever that call yields as a T.
template <typename T>
struct LerpFunctor {
  C10_DEVICE inline T operator()(const T self, const T end, const T weight) {
    const T interpolated = lerp(self, end, weight);
    return interpolated;
  }
};
// Out-of-place foreach lerp whose third operand is itself a tensor list.
//
// Steps:
//   1. Validate the three lists with check_foreach_api_restrictions.
//   2. If can_use_fast_route rejects the inputs, defer to
//      foreach_tensor_ternary_lerp_slow and return its result.
//   3. Otherwise allocate one output per entry of `tensors1` via empty_like,
//      then run multi_tensor_apply over {tensors1, tensors2, tensors3, outputs}.
//
// The kernel is dispatched on the dtype of tensors1[0] over floating, complex,
// Half and BFloat16 types; the arithmetic functor is instantiated with
// at::opmath_type<scalar_t>.
// NOTE(review): presumably the three input lists map onto lerp's
// (self, end, weight) operands in order — confirm against TernaryOpListFunctor.
//
// Returns the list of freshly allocated result tensors (fast path) or the
// slow-path result.
std::vector<at::Tensor> foreach_tensor_lerp_ternary_cuda(TensorList tensors1, TensorList tensors2, TensorList tensors3) {
  check_foreach_api_restrictions(tensors1, tensors2, tensors3);
  if (!can_use_fast_route({tensors1, tensors2, tensors3})) {
    return foreach_tensor_ternary_lerp_slow(tensors1, tensors2, tensors3);
  }

  std::vector<at::Tensor> vec_res;
  vec_res.reserve(tensors1.size());
  for (const auto& t : tensors1) {
    vec_res.emplace_back(at::native::empty_like(t));
  }

  // vec_res is not used again after this point, so hand its storage over
  // instead of copying it into the list-of-lists.
  std::vector<std::vector<at::Tensor>> tensor_lists{
      tensors1.vec(), tensors2.vec(), tensors3.vec(), std::move(vec_res)};

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, tensors1[0].scalar_type(), "foreach_tensor_lerp_ternary_cuda",
      [&]() {
        using opmath_t = typename at::opmath_type<scalar_t>;
        // res_arg_index 3 points at tensor_lists[3], the output list.
        multi_tensor_apply<4>(
            tensor_lists,
            TernaryOpListFunctor<scalar_t, /* depth */ 4, /* r_args_depth */ 3, /* res_arg_index */ 3>(),
            LerpFunctor<opmath_t>());
      }
  );

  // A subscript expression is not eligible for implicit move-on-return, so
  // move explicitly; tensor_lists is destroyed right after this statement.
  return std::move(tensor_lists[3]);
}
// In-place foreach lerp whose third operand is a tensor list.
//
// After check_foreach_api_restrictions validates the lists, inputs rejected by
// can_use_fast_route are forwarded to foreach_tensor_ternary_lerp_slow_.
// Otherwise multi_tensor_apply runs over {tensors1, tensors2, tensors3} with
// res_arg_index 0, i.e. the result slot is the first list (`tensors1`).
// Dispatch is on the dtype of tensors1[0] (floating, complex, Half, BFloat16).
void foreach_tensor_lerp_ternary_cuda_(TensorList tensors1, TensorList tensors2, TensorList tensors3) {
  check_foreach_api_restrictions(tensors1, tensors2, tensors3);

  const bool fast_path_ok = can_use_fast_route({tensors1, tensors2, tensors3});
  if (!fast_path_ok) {
    foreach_tensor_ternary_lerp_slow_(tensors1, tensors2, tensors3);
    return;
  }

  std::vector<std::vector<at::Tensor>> operand_lists{tensors1.vec(), tensors2.vec(), tensors3.vec()};
  const auto dispatch_dtype = tensors1[0].scalar_type();

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      dispatch_dtype,
      "foreach_tensor_lerp_ternary_cuda_",
      [&]() {
        // Arithmetic type used by the lerp functor for this scalar_t.
        using compute_t = at::opmath_type<scalar_t>;
        using list_functor_t =
            TernaryOpListFunctor<scalar_t, /* depth */ 3, /* r_args_depth */ 3, /* res_arg_index */ 0>;
        multi_tensor_apply<3>(operand_lists, list_functor_t(), LerpFunctor<compute_t>());
      });
}
// Out-of-place foreach lerp with a single Scalar weight shared by every
// tensor pair.
//
// Steps:
//   1. Validate the two lists with check_foreach_api_restrictions.
//   2. If can_use_fast_route rejects the inputs, defer to
//      foreach_tensor_lerp_list_kernel_slow and return its result.
//   3. Otherwise allocate one output per entry of `tensors1` via empty_like,
//      then run multi_tensor_apply over {tensors1, tensors2, outputs}, passing
//      `weight` converted to at::opmath_type<scalar_t> as an extra argument.
//
// The kernel is dispatched on the dtype of tensors1[0] over floating, complex,
// Half and BFloat16 types.
//
// Returns the list of freshly allocated result tensors (fast path) or the
// slow-path result.
std::vector<at::Tensor> foreach_tensor_lerp_list_cuda(TensorList tensors1, TensorList tensors2, const Scalar& weight) {
  check_foreach_api_restrictions(tensors1, tensors2);
  if (!can_use_fast_route({tensors1, tensors2})) {
    return foreach_tensor_lerp_list_kernel_slow(tensors1, tensors2, weight);
  }

  std::vector<at::Tensor> vec_res;
  vec_res.reserve(tensors1.size());
  for (const auto& t : tensors1) {
    vec_res.emplace_back(at::native::empty_like(t));
  }

  // vec_res is not used again after this point, so hand its storage over
  // instead of copying it into the list-of-lists.
  std::vector<std::vector<at::Tensor>> tensor_lists{
      tensors1.vec(), tensors2.vec(), std::move(vec_res)};

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16, tensors1[0].scalar_type(), "foreach_tensor_lerp_scalar_cuda",
      [&]() {
        using opmath_t = typename at::opmath_type<scalar_t>;
        // res_arg_index 2 points at tensor_lists[2], the output list.
        multi_tensor_apply<3>(
            tensor_lists,
            TernaryOpScalarFunctor<scalar_t, /* depth */ 3, /* r_args_depth */ 2, /* res_arg_index */ 2>(),
            LerpFunctor<opmath_t>(),
            weight.to<opmath_t>());
      }
  );

  // A subscript expression is not eligible for implicit move-on-return, so
  // move explicitly; tensor_lists is destroyed right after this statement.
  return std::move(tensor_lists[2]);
}
// In-place foreach lerp with a single Scalar weight.
//
// After check_foreach_api_restrictions validates the lists, inputs rejected by
// can_use_fast_route are forwarded to foreach_tensor_lerp_list_kernel_slow_.
// Otherwise multi_tensor_apply runs over {tensors1, tensors2} with
// res_arg_index 0, i.e. the result slot is the first list (`tensors1`), and
// `weight` is passed along converted to at::opmath_type<scalar_t>.
// Dispatch is on the dtype of tensors1[0] (floating, complex, Half, BFloat16).
void foreach_tensor_lerp_list_cuda_(TensorList tensors1, TensorList tensors2, const Scalar& weight) {
  check_foreach_api_restrictions(tensors1, tensors2);

  const bool fast_path_ok = can_use_fast_route({tensors1, tensors2});
  if (!fast_path_ok) {
    foreach_tensor_lerp_list_kernel_slow_(tensors1, tensors2, weight);
    return;
  }

  std::vector<std::vector<at::Tensor>> operand_lists{tensors1.vec(), tensors2.vec()};
  const auto dispatch_dtype = tensors1[0].scalar_type();

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      dispatch_dtype,
      "foreach_tensor_lerp_scalar_cuda_",
      [&]() {
        // Arithmetic type used both for the functor and the scalar weight.
        using compute_t = at::opmath_type<scalar_t>;
        using scalar_functor_t =
            TernaryOpScalarFunctor<scalar_t, /* depth */ 2, /* r_args_depth */ 2, /* res_arg_index */ 0>;
        multi_tensor_apply<2>(
            operand_lists,
            scalar_functor_t(),
            LerpFunctor<compute_t>(),
            weight.to<compute_t>());
      });
}
} // namespace at::native