Implementation of broadcast div backward by reduce #38044
Conversation
Thanks for your contribution!
Force-pushed from 052614c to d3173f8
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; }
};

template <typename T>
struct MulDxDyFunctor<paddle::platform::complex<T>> {
Couldn't these just be called MulFunctor and DivFunctor?
Renamed to MulGradFunctor and DivGradFunctor.
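For reference, a minimal sketch of what the renamed pair could look like (the complex specialization is an assumption, modeled on the y_conj pattern in the div excerpt further down):

template <typename T>
struct MulGradFunctor {
  // Real-typed body unchanged from the excerpt above.
  inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; }
};

template <typename T>
struct MulGradFunctor<paddle::platform::complex<T>> {
  // Complex case multiplies by the conjugate of the second operand.
  inline HOSTDEVICE paddle::platform::complex<T> operator()(
      const paddle::platform::complex<T>& a,
      const paddle::platform::complex<T>& b) const {
    paddle::platform::complex<T> b_conj(b.real, -b.imag);
    return a * b_conj;
  }
};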
Force-pushed from 485f716 to c6cef2e
const paddle::platform::complex<T>& y) const {
  paddle::platform::complex<T> y_conj(y.real, -y.imag);
  return x / y_conj;
}
Unify DivGradFunctor's parameter names: either x and y, or a and b. Same for MulGradFunctor.
Done.
ins.emplace_back(y);
outs.emplace_back(&res_dy);

const auto& cuda_ctx =
Rename cuda_ctx to dev_ctx throughout.
Done.
// x * y / z
template <typename T>
struct MulDivGradFunctor {
Rename this to DivGradYFunctor.
Done.
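A minimal sketch of the renamed ternary functor, assuming dy = -dout * out / y as in the kernel excerpts later in this thread (the argument order is an assumption):

// Computes dy = -dout * out / y; ternary, hence the kTernary suggestion below.
template <typename T>
struct DivGradYFunctor {
  inline HOSTDEVICE T operator()(const T& a, const T& b, const T& c) const {
    return -a * b / c;  // a = dout, b = out, c = y (assumed)
  }
};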
Force-pushed from 20aa561 to 22f4434
Force-pushed from 22f4434 to 9265a8d
std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, CustomSum>(res_dx, dx, reduce_dims, stream);
The Reduce interface has changed and this needs to be updated accordingly; see #38135.
Done.
std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, CustomSub>(res_dy, dy, reduce_dims, stream);
The Reduce interface needs to be updated.
Done.
ins.emplace_back(dout);
ins.emplace_back(out);
ins.emplace_back(y);
outs.emplace_back(&res_dy);
The vectors can be initialized at construction; there is no need for emplace_back().
Done.
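For example, matching the initializer-list style that appears later in this diff:

std::vector<const framework::Tensor*> ins = {dout, out, y};
std::vector<framework::Tensor*> outs = {&res_dy};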
const auto& dev_ctx =
    ctx.template device_context<platform::CUDADeviceContext>();
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
Since DivGradYFunctor is ternary, kTernary would be more appropriate than kBinary here.
Done.
Force-pushed from f87757d to be21b4b
Force-pushed from be21b4b to c3780ac
Force-pushed from c3780ac to f0f1cf3
framework::Tensor tmp_dx;
tmp_dx.Resize(dout->dims());

ElementwiseComputeEx<DivGradFunctor<T>, DeviceContext, T>(
Pure GPU code should not call this interface; it is only meant for paths that must support both CPU and GPU computation. Going through LaunchElementwiseCudaKernel is more direct for GPU-only code.
Done.
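That is, roughly the launch form used elsewhere in this PR (a sketch; the exact ins/outs for this hunk are assumed):

std::vector<const framework::Tensor*> ins = {dout, y};
std::vector<framework::Tensor*> outs = {&tmp_dx};
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
    dev_ctx, ins, &outs, axis, DivGradFunctor<T>());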
}
if (dx->dims() == dout->dims()) {
  // dx = dout/y
  ElementwiseComputeEx<DivGradFunctor<T>, DeviceContext, T>(
Same as below.
Done.
dx[col] = o / y_conj;
if (dx != nullptr) {
  dx[col] = o / y_conj;
}
dy[col] = -o * out_div_y_conj;
col += blockDim.x * gridDim.x;
This can be rewritten as a grid-stride loop; see: https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
Done.
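A sketch of the same body as a grid-stride loop (size, dout, out, and y are assumed to be the kernel's parameters):

// Each thread strides across the whole index space, so the kernel is
// correct for any launch configuration.
for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < size;
     col += blockDim.x * gridDim.x) {
  paddle::platform::complex<T> o = dout[col];
  paddle::platform::complex<T> y_conj(y[col].real, -y[col].imag);
  if (dx != nullptr) {
    dx[col] = o / y_conj;
  }
  dy[col] = -o * (out[col] / y_conj);
}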
dx[col] = o / y_conj;
if (dx != nullptr) {
  dx[col] = o / y_conj;
}
dy[col] = -o * out_div_y_conj;
col += blockDim.x * gridDim.x;
Same as above.
Done.
Force-pushed from 1c91d6b to e07e54e
}

template <typename T>
void reduce_functor(const framework::ExecutionContext& ctx,
Please rename the functions to UpperCamelCase.
Done.
template <typename DeviceContext, typename T>
typename std::enable_if<
    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_div_grad(const framework::ExecutionContext& ctx,
Same as above.
Done.
template <typename T>
void reduce_functor(const framework::ExecutionContext& ctx,
                    const framework::Tensor* in, const framework::Tensor* out,
What is the difference between in and src, and between out and dst? What is each of them for? Could you distinguish them or add an explanation?
in and out are used to compute reduce_dims; src holds the values to be reduced, and dst holds the reduced result. A comment can be added to explain this.
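For instance, a comment along these lines (the trailing src/dst parameters are assumed from the reply; the name follows the UpperCamelCase request above):

// in / out: only their dims() are consulted, to derive reduce_dims
//           for the broadcast axes.
// src: the tensor whose values are reduced.
// dst: receives the reduced result.
template <typename T>
void ReduceFunctor(const framework::ExecutionContext& ctx,
                   const framework::Tensor* in, const framework::Tensor* out,
                   const framework::Tensor* src, framework::Tensor* dst);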
}

template <typename T>
void reduce_functor(const framework::ExecutionContext& ctx,
The CUDA device context can be passed in directly here.
Done.
dx[col] = o / y_conj;
dy[col] = -o * out_div_y_conj;
col += blockDim.x * gridDim.x;
if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
Just call reduce_functor twice; this if/else is not needed.
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
    dev_ctx, ins, &outs, axis, DivGradFunctor<T>());
if (dx->dims() != dout->dims()) {
  reduce_functor<T>(ctx, x, out, &tmp_dx, dx);
Same as above.
std::vector<framework::Tensor*> outs = {&tmp_dy};
LaunchElementwiseCudaKernel<ElementwiseType::kTernary, T, T>(
    dev_ctx, ins, &outs, axis, DivGradYFunctor<T>());
if (dy->dims() != dout->dims()) {
Same as above.
ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>(
    ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(),
    DivGradDY<T>());
default_elementwise_div_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
Please rename default as well, e.g. to Common, or something better.
This will be unified in a follow-up change.
dy[col] = -o * out_div_y_conj;
col += blockDim.x * gridDim.x;
if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
  dx->ShareDataWith(tmp_dx);
ShareDataWith assigns the tensor tmp_dx to dx, which can cause problems at model run time; please avoid this pattern.
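One way to avoid it (a sketch, consistent with the outs = {dx, dy} pattern that appears later in this diff): when the dims already match, pass dx and dy directly as kernel outputs instead of temporaries:

if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
  outs = {dx, dy};  // write gradients in place; no tmp + ShareDataWith
}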
// Complex div grad
template <typename T>
struct DivGradFunctor<Complex<T>> {
Shouldn't this be written as GradX, to match GradY?
Done.
std::vector<const framework::Tensor*> ins = {dout, out, y};
std::vector<framework::Tensor*> outs;
if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
This branch can be deleted, because in that case this function is never reached.
The original CUDA function for identical dims has been removed, so the identical-dims case now goes through this branch; keeping it.
framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
auto* dout_data = dout->data<T>();
dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1);
block_size is defined but never used.
Removed.
}

template <typename T>
void ReduceForDiv(const platform::CUDADeviceContext& dev_ctx, int axis,
This wrapper is not specific to div; could it be renamed to ReduceWrapper()?
Extracted into a common helper function.
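A sketch of what the extracted helper might look like, pieced together from the GetReduceDim/TensorReduceFunctorImpl excerpts earlier in this thread (the body shown uses the pre-#38135 reduce interface and is an assumption):

// Reduces src down to dst's shape by summing over the broadcast dims;
// not div-specific, hence the generic name.
template <typename T>
void ReduceWrapper(const platform::CUDADeviceContext& dev_ctx, int axis,
                   framework::Tensor* src, framework::Tensor* dst) {
  std::vector<int> reduce_dims = GetReduceDim(dst->dims(), src->dims(), axis);
  TensorReduceFunctorImpl<T, T, CustomSum>(*src, dst, reduce_dims,
                                           dev_ctx.stream());
}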
@@ -146,14 +167,11 @@ class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");

if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
DefaultElementwiseDivGrad already covers this branch; it can be deleted.
Done.
std::vector<const framework::Tensor*> ins = {dout, out, y};
std::vector<framework::Tensor*> outs;
if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
This branch is redundant and can be removed: when (dx->dims() == dy->dims()), the outer caller never reaches this function. Alternatively, keep it and remove the outer dx != nullptr && dy != nullptr && (dx->dims() == dy->dims()) branch instead. See below.
template <typename InT, typename OutT>
struct DivGradXYFunctor {
  inline HOSTDEVICE paddle::framework::Array<OutT, 2> operator()(InT a, InT b,
Pass read-only parameters by const reference; same below.
Done.
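That is, roughly (a sketch; the third parameter and body are assumed from the dx = dout / y and dy = -dout * out / y formulas quoted in this thread):

template <typename InT, typename OutT>
struct DivGradXYFunctor {
  // a = dout, b = out, c = y (assumed); read-only arguments by const reference.
  inline HOSTDEVICE paddle::framework::Array<OutT, 2> operator()(
      const InT& a, const InT& b, const InT& c) const {
    paddle::framework::Array<OutT, 2> outs;
    outs[0] = a / c;       // dx = dout / y
    outs[1] = -a * b / c;  // dy = -dout * out / y
    return outs;
  }
};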
framework::Tensor tmp_dy;
tmp_dy.mutable_data<T>(dout->dims(), ctx.GetPlace());
if (dx != nullptr && dy != nullptr) {
  auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
The result of mutable_data does not need to be stored in a pointer (the pointer is not used afterwards); same below.
Done.
dx[col] = o / y[col];
dy[col] = -o * out[col] / y[col];
col += blockDim.x * gridDim.x;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
Is this function still needed?
This function has been removed; the multi-output branch is used instead.
framework::Tensor tmp_dx;
tmp_dx.mutable_data<T>(dout->dims(), ctx.GetPlace());
framework::Tensor tmp_dy;
tmp_dy.mutable_data<T>(dout->dims(), ctx.GetPlace());
Not every case needs a temporary Tensor, right?
Allocation of the temporary Tensors has been moved into the if/else branches.
ReduceWrapper<T>(dev_ctx, axis, &tmp_dy, dy);
}
}
#endif
GetGradXOut and GetGradYOut are nearly identical; they can be merged into one function.
Done.
if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
  outs = {dx, dy};
}
if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {
This should be else if, shouldn't it?
Done.
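That is, the dims checks chained so exactly one branch runs (a sketch; the last two combinations are assumed):

if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
  outs = {dx, dy};            // both written in place
} else if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {
  outs = {&tmp_dx, dy};       // dx reduced afterwards
} else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) {
  outs = {dx, &tmp_dy};       // dy reduced afterwards
} else {
  outs = {&tmp_dx, &tmp_dy};  // both reduced afterwards
}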
const auto& dev_ctx =
    ctx.template device_context<platform::CUDADeviceContext>();
if (dx != nullptr && dy != nullptr) {
  GetGradXYOut<T>(dev_ctx, axis, x, y, out, dout, dx, dy,
It feels like these could be folded into one function: pass null pointers for the unused arguments, and then only the functor differs.
If null pointers were passed in, GetGradXYOut would have to check the pointers repeatedly, which hurts readability.
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
The header includes have been removed.
I agree with this PR, but there are still some suggested modifications that need to be discussed with the other reviewers. Please discuss with them as soon as possible, and please do not force-push once the review process has started.
LGTM
PR types: Performance optimization
PR changes: OPs
Describe: