-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implementation of broadcast div backward by reduce #38044
Changes from 11 commits
d3173f8
c6cef2e
9265a8d
080bf95
f0f1cf3
8c43581
b1f58dc
e07e54e
3594f6b
7adf371
560ed45
2920824
476c797
8259c34
d2f3776
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ limitations under the License. */ | |
|
||
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h" | ||
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" | ||
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" | ||
#include "paddle/fluid/platform/complex.h" | ||
#include "paddle/fluid/platform/float16.h" | ||
|
||
|
@@ -29,13 +30,11 @@ static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y, | |
const T* dout, | ||
int64_t size, T* dx, | ||
T* dy) { | ||
int col = blockIdx.x * blockDim.x + threadIdx.x; | ||
|
||
while (col < size) { | ||
T o = dout[col]; | ||
dx[col] = o / y[col]; | ||
dy[col] = -o * out[col] / y[col]; | ||
col += blockDim.x * gridDim.x; | ||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个函数还有必要吗? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 该函数目前已经删掉,走多输出分支 |
||
i += blockDim.x * gridDim.x) { | ||
T o = dout[i]; | ||
dx[i] = o / y[i]; | ||
dy[i] = -o * out[i] / y[i]; | ||
} | ||
} | ||
|
||
|
@@ -48,16 +47,14 @@ SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<float>>( | |
const paddle::platform::complex<float>* dout, int64_t size, | ||
paddle::platform::complex<float>* dx, | ||
paddle::platform::complex<float>* dy) { | ||
int col = blockIdx.x * blockDim.x + threadIdx.x; | ||
|
||
while (col < size) { | ||
paddle::platform::complex<float> o = dout[col]; | ||
paddle::platform::complex<float> y_conj(y[col].real, -y[col].imag); | ||
paddle::platform::complex<float> out_div_y_conj((out[col] / y[col]).real, | ||
-(out[col] / y[col]).imag); | ||
dx[col] = o / y_conj; | ||
dy[col] = -o * out_div_y_conj; | ||
col += blockDim.x * gridDim.x; | ||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; | ||
i += blockDim.x * gridDim.x) { | ||
paddle::platform::complex<float> o = dout[i]; | ||
paddle::platform::complex<float> y_conj(y[i].real, -y[i].imag); | ||
paddle::platform::complex<float> out_div_y_conj((out[i] / y[i]).real, | ||
-(out[i] / y[i]).imag); | ||
dx[i] = o / y_conj; | ||
dy[i] = -dout[i] * out_div_y_conj; | ||
} | ||
} | ||
|
||
|
@@ -70,27 +67,125 @@ SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<double>>( | |
const paddle::platform::complex<double>* dout, int64_t size, | ||
paddle::platform::complex<double>* dx, | ||
paddle::platform::complex<double>* dy) { | ||
int col = blockIdx.x * blockDim.x + threadIdx.x; | ||
|
||
while (col < size) { | ||
paddle::platform::complex<double> o = dout[col]; | ||
paddle::platform::complex<double> y_conj(y[col].real, -y[col].imag); | ||
paddle::platform::complex<double> out_div_y_conj((out[col] / y[col]).real, | ||
-(out[col] / y[col]).imag); | ||
dx[col] = o / y_conj; | ||
dy[col] = -o * out_div_y_conj; | ||
col += blockDim.x * gridDim.x; | ||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; | ||
i += blockDim.x * gridDim.x) { | ||
paddle::platform::complex<double> o = dout[i]; | ||
paddle::platform::complex<double> y_conj(y[i].real, -y[i].imag); | ||
paddle::platform::complex<double> out_div_y_conj((out[i] / y[i]).real, | ||
-(out[i] / y[i]).imag); | ||
dx[i] = o / y_conj; | ||
dy[i] = -dout[i] * out_div_y_conj; | ||
} | ||
} | ||
|
||
template <typename T> | ||
void ReduceForDiv(const platform::CUDADeviceContext& dev_ctx, int axis, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个封装不止适用于div吧,是不是能改成ReduceWrapper() There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已经提出公共接口函数 |
||
const framework::Tensor* in, const framework::Tensor* out, | ||
framework::Tensor* src, framework::Tensor* dst) { | ||
std::vector<int> reduce_dims = GetReduceDim(in->dims(), out->dims(), axis); | ||
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>( | ||
*src, dst, kps::IdentityFunctor<T>(), reduce_dims, dev_ctx.stream()); | ||
} | ||
|
||
template <typename DeviceContext, typename T> | ||
typename std::enable_if< | ||
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type | ||
DefaultElementwiseDivGrad(const framework::ExecutionContext& ctx, | ||
const framework::Tensor* x, | ||
const framework::Tensor* y, | ||
const framework::Tensor* out, | ||
const framework::Tensor* dout, framework::Tensor* dx, | ||
framework::Tensor* dy) { | ||
int axis = ctx.Attr<int>("axis"); | ||
auto* dout_data = dout->data<T>(); | ||
dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. block_size 定义了但没有被使用 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已经删掉 |
||
const auto& dev_ctx = | ||
ctx.template device_context<platform::CUDADeviceContext>(); | ||
framework::Tensor tmp_dx; | ||
tmp_dx.mutable_data<T>(dout->dims(), ctx.GetPlace()); | ||
framework::Tensor tmp_dy; | ||
tmp_dy.mutable_data<T>(dout->dims(), ctx.GetPlace()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 并不是所有情况都需要使用临时Tensor吧? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 申请临时Tensor空间放在了if else分支中 |
||
if (dx != nullptr && dy != nullptr) { | ||
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mutable_data的结果不必传给指针(下文没用到指针),下同 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
auto* dy_data = dy->mutable_data<T>(ctx.GetPlace()); | ||
// For inplace strategy, dx will be stored in addr of dout, which makes | ||
// the result of dy wrong. | ||
if (dx->IsSharedBufferWith(*dout)) { | ||
dx->clear(); | ||
dx->mutable_data<T>(x->dims(), ctx.GetPlace()); | ||
} | ||
|
||
std::vector<const framework::Tensor*> ins = {dout, out, y}; | ||
std::vector<framework::Tensor*> outs; | ||
if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个分支可以删掉,因为这种情况下根本不会进到这个接口 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 删掉了原来相同dims 的CUDA函数,相同dims走该分支,所以保留 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个分支重复,可以去掉,当(dx->dims() == dy->dims())时,外层调用不会执行到该函数。 |
||
outs = {dx, dy}; | ||
} else if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) { | ||
outs = {&tmp_dx, dy}; | ||
} else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) { | ||
outs = {dx, &tmp_dy}; | ||
} else if (dx->dims() != dout->dims() && dy->dims() != dout->dims()) { | ||
outs = {&tmp_dx, &tmp_dy}; | ||
} | ||
|
||
auto functor = DivGradXYFunctor<T, T>(); | ||
LaunchElementwiseCudaKernel<ElementwiseType::kTernary, T, T, | ||
decltype(functor), 2>(dev_ctx, ins, &outs, axis, | ||
functor); | ||
|
||
if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) { | ||
ReduceForDiv<T>(dev_ctx, axis, x, out, &tmp_dx, dx); | ||
} else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) { | ||
ReduceForDiv<T>(dev_ctx, axis, y, out, &tmp_dy, dy); | ||
} else if (dx->dims() != dout->dims() && dy->dims() != dout->dims()) { | ||
ReduceForDiv<T>(dev_ctx, axis, x, out, &tmp_dx, dx); | ||
ReduceForDiv<T>(dev_ctx, axis, y, out, &tmp_dy, dy); | ||
} | ||
} else if (dx != nullptr && dy == nullptr) { | ||
Zjq9409 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace()); | ||
if (dx->IsSharedBufferWith(*dout)) { | ||
dx->clear(); | ||
dx->mutable_data<T>(x->dims(), ctx.GetPlace()); | ||
} | ||
|
||
std::vector<const framework::Tensor*> ins = {dout, y}; | ||
std::vector<framework::Tensor*> outs; | ||
if (dx->dims() != dout->dims()) { | ||
outs = {&tmp_dx}; | ||
} else { | ||
outs = {dx}; | ||
} | ||
|
||
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>( | ||
dev_ctx, ins, &outs, axis, DivGradFunctor<T>()); | ||
if (dx->dims() != dout->dims()) { | ||
ReduceForDiv<T>(dev_ctx, axis, x, out, &tmp_dx, dx); | ||
} | ||
} else if (dy != nullptr && dx == nullptr) { | ||
auto* dy_data = dy->mutable_data<T>(ctx.GetPlace()); | ||
|
||
std::vector<const framework::Tensor*> ins = {dout, out, y}; | ||
std::vector<framework::Tensor*> outs; | ||
if (dy->dims() != dout->dims()) { | ||
outs = {&tmp_dy}; | ||
} else { | ||
outs = {dy}; | ||
} | ||
|
||
LaunchElementwiseCudaKernel<ElementwiseType::kTernary, T, T>( | ||
dev_ctx, ins, &outs, axis, DivGradYFunctor<T>()); | ||
if (dy->dims() != dout->dims()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 同上 |
||
ReduceForDiv<T>(dev_ctx, axis, y, out, &tmp_dy, dy); | ||
} | ||
} | ||
} | ||
|
||
template <typename DeviceContext, typename T> | ||
typename std::enable_if< | ||
std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type | ||
elementwise_div_grad(const framework::ExecutionContext& ctx, | ||
const framework::Tensor* x, const framework::Tensor* y, | ||
const framework::Tensor* out, | ||
const framework::Tensor* dout, framework::Tensor* dx, | ||
framework::Tensor* dy) { | ||
ElementwiseDivGrad(const framework::ExecutionContext& ctx, | ||
const framework::Tensor* x, const framework::Tensor* y, | ||
const framework::Tensor* out, const framework::Tensor* dout, | ||
framework::Tensor* dx, framework::Tensor* dy) { | ||
dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1); | ||
auto size = x->numel(); | ||
dim3 grid_size = | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -111,26 +111,47 @@ struct DivDoubleDY { | |
template <typename DeviceContext, typename T> | ||
typename std::enable_if< | ||
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type | ||
elementwise_div_grad(const framework::ExecutionContext& ctx, | ||
const framework::Tensor* x, const framework::Tensor* y, | ||
const framework::Tensor* out, | ||
const framework::Tensor* dout, framework::Tensor* dx, | ||
framework::Tensor* dy) { | ||
DefaultElementwiseDivGrad(const framework::ExecutionContext& ctx, | ||
const framework::Tensor* x, | ||
const framework::Tensor* y, | ||
const framework::Tensor* out, | ||
const framework::Tensor* dout, framework::Tensor* dx, | ||
framework::Tensor* dy) { | ||
int axis = ctx.Attr<int>("axis"); | ||
|
||
ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>( | ||
ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(), DivGradDY<T>()); | ||
} | ||
|
||
template <typename DeviceContext, typename T> | ||
typename std::enable_if< | ||
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type | ||
ElementwiseDivGrad(const framework::ExecutionContext& ctx, | ||
const framework::Tensor* x, const framework::Tensor* y, | ||
const framework::Tensor* out, const framework::Tensor* dout, | ||
framework::Tensor* dx, framework::Tensor* dy) { | ||
DefaultElementwiseDivGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy); | ||
} | ||
|
||
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
template <typename DeviceContext, typename T> | ||
// cuda definition | ||
typename std::enable_if< | ||
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type | ||
DefaultElementwiseDivGrad(const framework::ExecutionContext& ctx, | ||
const framework::Tensor* x, | ||
const framework::Tensor* y, | ||
const framework::Tensor* out, | ||
const framework::Tensor* dout, framework::Tensor* dx, | ||
framework::Tensor* dy); | ||
|
||
template <typename DeviceContext, typename T> | ||
typename std::enable_if< | ||
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type | ||
elementwise_div_grad(const framework::ExecutionContext& ctx, | ||
const framework::Tensor* x, const framework::Tensor* y, | ||
const framework::Tensor* out, | ||
const framework::Tensor* dout, framework::Tensor* dx, | ||
framework::Tensor* dy); | ||
ElementwiseDivGrad(const framework::ExecutionContext& ctx, | ||
const framework::Tensor* x, const framework::Tensor* y, | ||
const framework::Tensor* out, const framework::Tensor* dout, | ||
framework::Tensor* dx, framework::Tensor* dy); | ||
#endif | ||
|
||
template <typename DeviceContext, typename T> | ||
|
@@ -146,14 +167,11 @@ class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> { | |
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); | ||
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); | ||
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); | ||
int axis = ctx.Attr<int>("axis"); | ||
|
||
if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. DefaultElementwiseDivGrad已经包括这个分支了,可以删除 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
elementwise_div_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy); | ||
ElementwiseDivGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy); | ||
} else { | ||
ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>( | ||
ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(), | ||
DivGradDY<T>()); | ||
DefaultElementwiseDivGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy); | ||
} | ||
} | ||
}; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,8 @@ limitations under the License. */ | |
|
||
#pragma once | ||
|
||
#include "paddle/fluid/framework/array.h" | ||
#include "paddle/fluid/platform/complex.h" | ||
#include "paddle/fluid/platform/enforce.h" | ||
#include "paddle/fluid/platform/float16.h" | ||
#include "paddle/fluid/platform/hostdevice.h" | ||
|
@@ -113,6 +115,70 @@ struct MinFunctor { | |
} | ||
}; | ||
|
||
template <typename T> | ||
using Complex = paddle::platform::complex<T>; | ||
|
||
template <typename InT, typename OutT> | ||
struct DivGradXYFunctor { | ||
inline HOSTDEVICE paddle::framework::Array<OutT, 2> operator()(InT a, InT b, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 只读参数传 const reference,下同 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
InT c) { | ||
// dx = dout / y | ||
// dy = - dout * out / y | ||
paddle::framework::Array<OutT, 2> outs; | ||
outs[0] = a / c; | ||
outs[1] = -a * b / c; | ||
return outs; | ||
} | ||
}; | ||
|
||
template <typename InT, typename OutT> | ||
struct DivGradXYFunctor<Complex<InT>, Complex<OutT>> { | ||
inline HOSTDEVICE paddle::framework::Array<Complex<OutT>, 2> operator()( | ||
Complex<InT> a, Complex<InT> b, Complex<InT> c) { | ||
paddle::framework::Array<Complex<OutT>, 2> outs; | ||
Complex<InT> c_conj(c.real, -c.imag); | ||
Complex<InT> out_div_y_conj((b / c).real, -(b / c).imag); | ||
outs[0] = a / c_conj; | ||
outs[1] = -a * out_div_y_conj; | ||
return outs; | ||
} | ||
}; | ||
|
||
// Float div grad | ||
template <typename T> | ||
struct DivGradFunctor { | ||
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a / b; } | ||
}; | ||
|
||
// Complex div grad | ||
template <typename T> | ||
struct DivGradFunctor<Complex<T>> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里是不是跟GradY对应起来写成GradX There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
inline HOSTDEVICE Complex<T> operator()(const Complex<T>& a, | ||
const Complex<T>& b) const { | ||
Complex<T> b_conj(b.real, -b.imag); | ||
return a / b_conj; | ||
} | ||
}; | ||
|
||
// Float mul and div | ||
template <typename T> | ||
struct DivGradYFunctor { | ||
inline HOSTDEVICE T operator()(const T& a, const T& b, const T& c) const { | ||
return -a * b / c; | ||
} | ||
}; | ||
|
||
// Complex mul and div | ||
template <typename T> | ||
struct DivGradYFunctor<Complex<T>> { | ||
inline HOSTDEVICE Complex<T> operator()(const Complex<T>& a, | ||
const Complex<T>& b, | ||
const Complex<T>& c) const { | ||
Complex<T> out_div_y_conj((b / c).real, -(b / c).imag); | ||
return -a * out_div_y_conj; | ||
} | ||
}; | ||
|
||
// Fmax | ||
template <typename T> | ||
struct FMaxFunctor { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
头文件已经删除