
implementation of broadcast div backward by reduce #38044

Merged
merged 15 commits into PaddlePaddle:develop on Jan 5, 2022

Conversation

@Zjq9409 (Contributor) commented Dec 10, 2021

PR types

Performance optimization

PR changes

OPs

Describe

| Case | PyTorch | Before optimization | Before vs. PyTorch | After optimization | After vs. PyTorch | Speedup |
|---|---|---|---|---|---|---|
| [50, 128, 1000], [128, 1000] | 0.46865 | 0.24259 | better (48.24%) | 0.23764 | better (49.29%) | 1.02 |
| [50, 128, 1000], [1, 128, 1000] | 0.46940 | 0.24346 | better (48.13%) | 0.23795 | better (49.30%) | 1.02 |
| [16, 2048, 7, 7], [16, 2048] | 0.14044 | 0.07819 | better (44.32%) | 0.07565 | better (45.84%) | 1.03 |
| [16, 2048, 16, 16], [16, 2048, 16, 16] | 0.71575 | 0.34497 | better (1.07x) | 0.34354 | better (1.07x) | 1.00 |
| [16, 1, 513, 513], [1] | 0.31762 | 4.67214 | worse (13.71x) | 0.15971 | better (49.51%) | 29.25 |
| [512, 896, 4, 12], [512, 896, 4, 1] | 1.68353 | 2.82219 | worse (67.64%) | 0.86215 | better (48.78%) | 3.27 |
| [512, 896, 4, 12], [512, 896, 4, 1] fp16 | 1.17390 | 2.74304 | worse (1.34x) | 0.60514 | better (48.67%) | 4.53 |
| [32, 12, 128, 128], [32, 1, 1, 128] fp16 | 0.34941 | 0.57034 | worse (63.23%) | 0.15004 | better (1.32x) | 3.80 |
| [32, 1, 1, 128], [1, 12, 128, 1] fp16 | 0.38124 | 0.4983 | worse (30.71%) | 0.19352 | better (49.29%) | 2.57 |

@paddle-bot-old commented:

Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; }
};
template <typename T>
struct MulDxDyFunctor<paddle::platform::complex<T>> {
Contributor:

Why not just call them MulFunctor and DivFunctor?

Contributor Author:

Renamed them to MulGradFunctor and DivGradFunctor.
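For reference, a minimal sketch of what the renamed gradient functor might compute, based on the calculus for z = x / y (dz/dx = dout / y); the struct name and template layout follow the snippets quoted in this review and are assumptions, not the final merged code:

```cpp
// Sketch only: dx = dout / y for z = x / y.
template <typename T>
struct DivGradFunctor {
  inline HOSTDEVICE T operator()(const T& a, const T& b) const {
    // a: upstream gradient dout, b: divisor y
    return a / b;
  }
};
```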

@Zjq9409 changed the title from "add elementwise div" to "implementation of broadcast div backward by reduce" on Dec 13, 2021
const paddle::platform::complex<T>& y) const {
paddle::platform::complex<T> y_conj(y.real, -y.imag);
return x / y_conj;
}
Contributor:

Make the DivGradFunctor parameter names consistent: either x and y, or a and b. Same for MulGradFunctor.

Contributor Author:

Done.

ins.emplace_back(y);
outs.emplace_back(&res_dy);

const auto& cuda_ctx =
Contributor:

Rename cuda_ctx to dev_ctx for consistency.

Contributor Author:

Done.


// x * y / z
template <typename T>
struct MulDivGradFunctor {
Contributor:

Rename this to DivGradYFunctor.

Contributor Author:

Done.
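As a rough illustration of the renamed ternary functor: using out = x / y, the gradient w.r.t. y is -dout * out / y. The argument order below follows the ins = {dout, out, y} vectors shown later in this PR and should be treated as an assumption:

```cpp
// Sketch only: dy = -dout * out / y for z = x / y.
template <typename T>
struct DivGradYFunctor {
  inline HOSTDEVICE T operator()(const T& a, const T& b, const T& c) const {
    // a: dout, b: out (= x / y), c: y
    return -a * b / c;
  }
};
```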


std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, CustomSum>(res_dx, dx, reduce_dims, stream);
Contributor:

The Reduce interface has changed and this needs to be updated accordingly; see #38135.

Contributor Author:

Done.


std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, CustomSub>(res_dy, dy, reduce_dims, stream);
Contributor:

The Reduce interface needs to be updated.

Contributor Author:

Done.

ins.emplace_back(dout);
ins.emplace_back(out);
ins.emplace_back(y);
outs.emplace_back(&res_dy);
Contributor:

The vectors can be initialized when they are created; the emplace_back() calls are unnecessary.

Contributor Author:

Done.
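For example, the emplace_back calls can be replaced with brace initialization at construction time (tensor names taken from the snippet above):

```cpp
// Initialize the input/output vectors directly instead of emplace_back().
std::vector<const framework::Tensor*> ins = {dout, out, y};
std::vector<framework::Tensor*> outs = {&res_dy};
```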


const auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
Contributor:

Since DivGradYFunctor is ternary, kTernary is more appropriate than kBinary here.

Contributor Author:

Done.

framework::Tensor tmp_dx;
tmp_dx.Resize(dout->dims());

ElementwiseComputeEx<DivGradFunctor<T>, DeviceContext, T>(
Contributor:

Don't call this interface from GPU-only code; it is meant for code that has to support both CPU and GPU computation. For purely GPU code, going through LaunchElementwiseCudaKernel is more straightforward.

Contributor Author:

Done.

}
if (dx->dims() == dout->dims()) {
// dx = dout/y
ElementwiseComputeEx<DivGradFunctor<T>, DeviceContext, T>(
Contributor:

Same as below.

Contributor Author:

Done.

dx[col] = o / y_conj;
if (dx != nullptr) {
dx[col] = o / y_conj;
}
dy[col] = -o * out_div_y_conj;
col += blockDim.x * gridDim.x;
Contributor:

This can be rewritten as a grid-stride loop; see https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/

Contributor Author:

Done.
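The grid-stride pattern from the linked NVIDIA post replaces the manual `col += blockDim.x * gridDim.x` bookkeeping with a plain for loop, so the kernel works for any launch configuration. A self-contained CUDA sketch (a generic scaling kernel, not the PR's actual gradient kernel):

```cpp
#include <cstdio>

// Grid-stride loop: each thread starts at its global index and strides by
// the total number of threads in the grid until the whole range is covered.
__global__ void scale_kernel(const float* in, float* out, float alpha, int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    out[i] = alpha * in[i];
  }
}

int main() {
  const int n = 1 << 20;
  float *in = nullptr, *out = nullptr;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, n * sizeof(float));
  for (int i = 0; i < n; ++i) in[i] = 1.0f;
  // Any reasonable launch configuration works; the loop covers the rest.
  scale_kernel<<<256, 256>>>(in, out, 2.0f, n);
  cudaDeviceSynchronize();
  printf("out[0] = %f\n", out[0]);
  cudaFree(in);
  cudaFree(out);
  return 0;
}
```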

dx[col] = o / y_conj;
if (dx != nullptr) {
dx[col] = o / y_conj;
}
dy[col] = -o * out_div_y_conj;
col += blockDim.x * gridDim.x;
Contributor:

Same as above.

Contributor Author:

Done.

}

template <typename T>
void reduce_functor(const framework::ExecutionContext& ctx,
Contributor:

Rename the functions to UpperCamelCase.

Contributor Author:

Done.

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_div_grad(const framework::ExecutionContext& ctx,
Contributor:

Same as above.

Contributor Author:

Done.


template <typename T>
void reduce_functor(const framework::ExecutionContext& ctx,
const framework::Tensor* in, const framework::Tensor* out,
Contributor:

What is the difference between in and src, and between out and dst? What does each of them do? Could you distinguish them or add an explanation?

Contributor Author:

in and out are used to compute reduce_dims; src holds the values to be reduced, and dst holds the result after the reduce. A comment can be added to explain this.

}

template <typename T>
void reduce_functor(const framework::ExecutionContext& ctx,
Contributor:

The CUDA device context can be passed in directly here.

Contributor Author:

Done.

dx[col] = o / y_conj;
dy[col] = -o * out_div_y_conj;
col += blockDim.x * gridDim.x;
if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
Contributor:

Just calling reduce_functor twice is enough; this if/else is unnecessary.

LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, axis, DivGradFunctor<T>());
if (dx->dims() != dout->dims()) {
reduce_functor<T>(ctx, x, out, &tmp_dx, dx);
Contributor:

Same as above.

std::vector<framework::Tensor*> outs = {&tmp_dy};
LaunchElementwiseCudaKernel<ElementwiseType::kTernary, T, T>(
dev_ctx, ins, &outs, axis, DivGradYFunctor<T>());
if (dy->dims() != dout->dims()) {
Contributor:

Same as above.

ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(),
DivGradDY<T>());
default_elementwise_div_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
Contributor:

Please rename default as well, e.g. to Common, or something better.

Contributor Author:

This will be changed uniformly in a follow-up.

dy[col] = -o * out_div_y_conj;
col += blockDim.x * gridDim.x;
if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
dx->ShareDataWith(tmp_dx);
Contributor:

ShareDataWith assigns the tensor tmp_dx to dx, which may cause problems at model runtime; try to avoid this pattern.


// Complex div grad
template <typename T>
struct DivGradFunctor<Complex<T>> {
Contributor:

Shouldn't this be written as GradX, to correspond with GradY?

Contributor Author:

Done.


std::vector<const framework::Tensor*> ins = {dout, out, y};
std::vector<framework::Tensor*> outs;
if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
Contributor:

This branch can be removed, because in that case this interface is never reached.

Contributor Author:

The original same-dims CUDA function has been removed, and the same-dims case now goes through this branch, so it is kept.

framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
auto* dout_data = dout->data<T>();
dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1);
Contributor:

block_size is defined but never used.

Contributor Author:

Removed.

}

template <typename T>
void ReduceForDiv(const platform::CUDADeviceContext& dev_ctx, int axis,
Contributor:

This wrapper is not only applicable to div; could it be renamed to ReduceWrapper()?

Contributor Author:

It has been extracted into a common interface function.
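Based on the reduce calls quoted earlier in this review (GetReduceDim followed by TensorReduceFunctorImpl) and the ReduceWrapper call sites shown below, the extracted helper presumably looks roughly like the following. The reduce interface changed in #38135, so this mirrors the pre-update calls in the quoted snippets and is only a sketch:

```cpp
// Sketch: shape-agnostic reduce helper. src holds the broadcast-shaped
// intermediate gradient; dst is the final gradient tensor whose (smaller)
// shape determines which dimensions are reduced over.
template <typename T>
void ReduceWrapper(const platform::CUDADeviceContext& dev_ctx, int axis,
                   framework::Tensor* src, framework::Tensor* dst) {
  std::vector<int> reduce_dims = GetReduceDim(dst->dims(), src->dims(), axis);
  TensorReduceFunctorImpl<T, T, CustomSum>(*src, dst, reduce_dims,
                                           dev_ctx.stream());
}
```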

@@ -146,14 +167,11 @@ class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");

if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
Contributor:

DefaultElementwiseDivGrad already covers this branch, so it can be removed.

Contributor Author:

Done.


std::vector<const framework::Tensor*> ins = {dout, out, y};
std::vector<framework::Tensor*> outs;
if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
Contributor:

This branch is redundant and can be removed: when (dx->dims() == dy->dims()), the outer call never reaches this function.
Alternatively, keep it and instead delete the outer dx != nullptr && dy != nullptr && (dx->dims() == dy->dims()) branch. See below.


template <typename InT, typename OutT>
struct DivGradXYFunctor {
inline HOSTDEVICE paddle::framework::Array<OutT, 2> operator()(InT a, InT b,
Contributor:

Pass read-only parameters by const reference; same below.

Contributor Author:

Done.
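Applying the const-reference suggestion to the two-output functor quoted above gives roughly the following. The argument roles follow the ins = {dout, out, y} ordering used elsewhere in the PR and are an assumption:

```cpp
// Sketch: one functor producing both gradients of z = x / y.
// outs[0] = dx = dout / y, outs[1] = dy = -dout * out / y.
template <typename InT, typename OutT>
struct DivGradXYFunctor {
  inline HOSTDEVICE paddle::framework::Array<OutT, 2> operator()(
      const InT& a, const InT& b, const InT& c) const {
    paddle::framework::Array<OutT, 2> outs;
    outs[0] = a / c;       // a: dout, c: y
    outs[1] = -a * b / c;  // b: out (= x / y)
    return outs;
  }
};
```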

framework::Tensor tmp_dy;
tmp_dy.mutable_data<T>(dout->dims(), ctx.GetPlace());
if (dx != nullptr && dy != nullptr) {
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
Contributor:

The result of mutable_data does not need to be assigned to a pointer (the pointer is never used afterwards); same below.

Contributor Author:

Done.

dx[col] = o / y[col];
dy[col] = -o * out[col] / y[col];
col += blockDim.x * gridDim.x;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
Contributor:

Is this function still necessary?

Contributor Author:

This function has been removed; the multi-output branch is used instead.

framework::Tensor tmp_dx;
tmp_dx.mutable_data<T>(dout->dims(), ctx.GetPlace());
framework::Tensor tmp_dy;
tmp_dy.mutable_data<T>(dout->dims(), ctx.GetPlace());
Contributor:

Not every case needs a temporary Tensor, right?

Contributor Author:

Allocation of the temporary Tensors has been moved into the if/else branches.

ReduceWrapper<T>(dev_ctx, axis, &tmp_dy, dy);
}
}
#endif
Contributor:

GetGradXOut and GetGradYOut are too similar; they can be collapsed into one function.

Contributor Author:

Done.

if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
outs = {dx, dy};
}
if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {
Contributor:

Should this be else if?

Contributor Author:

Done.
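With the suggested else if, the branch selection reads as a single chain over the four shape combinations. A sketch based on the snippet above (tmp_dx and tmp_dy are the temporary broadcast-shaped buffers that are reduced back afterwards; treat the exact variable names as assumptions):

```cpp
// Sketch: choose the kernel outputs depending on which gradients need a
// post-reduce. Only shapes that differ from dout get a temporary buffer.
if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
  outs = {dx, dy};
} else if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {
  outs = {&tmp_dx, dy};
} else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) {
  outs = {dx, &tmp_dy};
} else {
  outs = {&tmp_dx, &tmp_dy};
}
```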

const auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
if (dx != nullptr && dy != nullptr) {
GetGradXYOut<T>(dev_ctx, axis, x, y, out, dout, dx, dy,
Contributor:

It seems these could be wrapped into a single function: pass a null pointer for the unused parameters, and then only the functor differs.

Contributor Author:

If a null pointer is passed in, GetGradXYOut would have to check the pointers for null in several places, which hurts readability.

@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
Contributor Author:

The header has been removed.

const auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
if (dx != nullptr && dy != nullptr) {
GetGradXYOut<T>(dev_ctx, axis, x, y, out, dout, dx, dy,
Contributor Author:

If a null pointer is passed in, GetGradXYOut would have to check the pointers for null in several places, which hurts readability.

if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
outs = {dx, dy};
}
if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {
Contributor Author:

Done.

ReduceWrapper<T>(dev_ctx, axis, &tmp_dy, dy);
}
}
#endif
Contributor Author:

Done.

@JamesLim-sy (Contributor) left a comment:

I agree with this PR, but there are still some suggested modifications that need to be discussed with the other reviewers. Please discuss them with the other reviewers as soon as possible, and please do not force-push once the review process has started.

@ZzSean (Contributor) left a comment:

LGTM

@PaddlePaddle PaddlePaddle deleted a comment from Zjq9409 Jan 4, 2022
@JamesLim-sy JamesLim-sy merged commit 55cd9cb into PaddlePaddle:develop Jan 5, 2022