implementation of broadcast div backward by reduce #38044

Merged · 15 commits · Jan 5, 2022
109 changes: 31 additions & 78 deletions paddle/fluid/operators/elementwise/elementwise_div_op.cu
@@ -13,93 +13,46 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

namespace paddle {
namespace operators {

template <typename T>
static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y,
const T* out,
const T* dout,
int64_t size, T* dx,
T* dy) {
int col = blockIdx.x * blockDim.x + threadIdx.x;

while (col < size) {
T o = dout[col];
dx[col] = o / y[col];
dy[col] = -o * out[col] / y[col];
col += blockDim.x * gridDim.x;
}
}
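
For reference, the two assignments in the loop body follow directly from the forward definition out = x / y (a standard derivation, included here for context rather than taken from the diff):

\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \mathrm{out}} \cdot \frac{1}{y} = \frac{\mathrm{dout}}{y},
\qquad
\frac{\partial L}{\partial y} = \frac{\partial L}{\partial \mathrm{out}} \cdot \left(-\frac{x}{y^{2}}\right) = -\frac{\mathrm{dout} \cdot \mathrm{out}}{y}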

template <>
__global__ void
SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<float>>(
const paddle::platform::complex<float>* x,
const paddle::platform::complex<float>* y,
const paddle::platform::complex<float>* out,
const paddle::platform::complex<float>* dout, int64_t size,
paddle::platform::complex<float>* dx,
paddle::platform::complex<float>* dy) {
int col = blockIdx.x * blockDim.x + threadIdx.x;

while (col < size) {
paddle::platform::complex<float> o = dout[col];
paddle::platform::complex<float> y_conj(y[col].real, -y[col].imag);
paddle::platform::complex<float> out_div_y_conj((out[col] / y[col]).real,
-(out[col] / y[col]).imag);
dx[col] = o / y_conj;
dy[col] = -o * out_div_y_conj;
col += blockDim.x * gridDim.x;
}
}

template <>
__global__ void
SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<double>>(
const paddle::platform::complex<double>* x,
const paddle::platform::complex<double>* y,
const paddle::platform::complex<double>* out,
const paddle::platform::complex<double>* dout, int64_t size,
paddle::platform::complex<double>* dx,
paddle::platform::complex<double>* dy) {
int col = blockIdx.x * blockDim.x + threadIdx.x;

while (col < size) {
paddle::platform::complex<double> o = dout[col];
paddle::platform::complex<double> y_conj(y[col].real, -y[col].imag);
paddle::platform::complex<double> out_div_y_conj((out[col] / y[col]).real,
-(out[col] / y[col]).imag);
dx[col] = o / y_conj;
dy[col] = -o * out_div_y_conj;
col += blockDim.x * gridDim.x;
}
}
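
The conjugates in the two complex specializations are deliberate: for a real-valued loss and complex inputs, the usual gradient convention multiplies dout by the conjugated partial derivative, so the derivation above becomes

\frac{\partial L}{\partial x} = \mathrm{dout} \cdot \overline{\left(\frac{1}{y}\right)} = \frac{\mathrm{dout}}{\bar{y}},
\qquad
\frac{\partial L}{\partial y} = \mathrm{dout} \cdot \overline{\left(-\frac{\mathrm{out}}{y}\right)} = -\mathrm{dout} \cdot \overline{\left(\frac{\mathrm{out}}{y}\right)}

which is exactly what y_conj and out_div_y_conj compute element by element.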

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
elementwise_div_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy) {
dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1);
auto size = x->numel();
dim3 grid_size =
dim3((size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1);
SimpleElemwiseDivGradCUDAKernel<
T><<<grid_size, block_size, 0,
ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
x->data<T>(), y->data<T>(), out->data<T>(), dout->data<T>(), size,
dx->mutable_data<T>(ctx.GetPlace()), dy->mutable_data<T>(ctx.GetPlace()));
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
ElementwiseDivGrad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out, const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
const auto& dev_ctx = ctx.template device_context<DeviceContext>();
const auto place = ctx.GetPlace();
if (dx != nullptr && dy != nullptr) {
dx->mutable_data<T>(place);
if (dx->IsSharedBufferWith(*dout)) {
dx->clear();
dx->mutable_data<T>(x->dims(), place);
}
std::vector<const framework::Tensor*> ins = {dout, out, y};
GetGradXAndYOut<ElementwiseType::kTernary, T>(
dev_ctx, place, axis, ins, dout, dx, dy, DivGradXYFunctor<T, T>());
} else if (dx != nullptr && dy == nullptr) {
dx->mutable_data<T>(place);
if (dx->IsSharedBufferWith(*dout)) {
dx->clear();
dx->mutable_data<T>(x->dims(), place);
}
std::vector<const framework::Tensor*> ins = {dout, y};
GetGradXOrYOut<ElementwiseType::kBinary, T>(dev_ctx, place, axis, ins, dout,
dx, DivGradXFunctor<T>());
} else if (dy != nullptr && dx == nullptr) {
std::vector<const framework::Tensor*> ins = {dout, out, y};
GetGradXOrYOut<ElementwiseType::kTernary, T>(
dev_ctx, place, axis, ins, dout, dy, DivGradYFunctor<T>());
}
}
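
A note on the dispatch above: when both dx and dy are requested, the fused DivGradXYFunctor produces both gradients in a single launch over the ternary input list {dout, out, y}; when only dx is needed, DivGradXFunctor gets by with the binary list {dout, y}; and when only dy is needed, DivGradYFunctor again consumes {dout, out, y}. The IsSharedBufferWith checks appear to cover the in-place case where dx is mapped onto dout's allocation: clearing and re-allocating with x's dims keeps the kernel from reading and writing the same storage when the shapes differ.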

} // namespace operators
29 changes: 10 additions & 19 deletions paddle/fluid/operators/elementwise/elementwise_div_op.h
@@ -111,26 +111,24 @@ struct DivDoubleDY {
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_div_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy) {
ElementwiseDivGrad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out, const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");

ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(), DivGradDY<T>());
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// cuda definition
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
elementwise_div_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy);
ElementwiseDivGrad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out, const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy);
#endif

template <typename DeviceContext, typename T>
@@ -146,15 +144,8 @@ class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");

if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
elementwise_div_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
} else {
ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(),
DivGradDY<T>());
}
ElementwiseDivGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
}
};
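
Compared with the removed branch above, Compute no longer special-cases on dx->dims() == dy->dims(): the old code took the fast elementwise path only for equal shapes and fell back to ElemwiseGradCompute otherwise, while the new ElementwiseDivGrad handles the broadcast case itself on CUDA by computing full-sized gradients and reducing them (see GetGradXAndYOut in elementwise_op_function.h below).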

67 changes: 67 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_functor.h
@@ -14,6 +14,8 @@ limitations under the License. */

#pragma once

#include "paddle/fluid/framework/array.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/hostdevice.h"
@@ -113,6 +115,71 @@ struct MinFunctor {
}
};

template <typename T>
using Complex = paddle::platform::complex<T>;

template <typename InT, typename OutT>
struct DivGradXYFunctor {
inline HOSTDEVICE paddle::framework::Array<OutT, 2> operator()(const InT a,
const InT b,
const InT c) {
// dx = dout / y
// dy = - dout * out / y
paddle::framework::Array<OutT, 2> outs;
outs[0] = a / c;
outs[1] = -a * b / c;
return outs;
}
};

template <typename InT, typename OutT>
struct DivGradXYFunctor<Complex<InT>, Complex<OutT>> {
inline HOSTDEVICE paddle::framework::Array<Complex<OutT>, 2> operator()(
const Complex<InT> a, const Complex<InT> b, const Complex<InT> c) {
paddle::framework::Array<Complex<OutT>, 2> outs;
Complex<InT> c_conj(c.real, -c.imag);
Complex<InT> out_div_c_conj((b / c).real, -(b / c).imag);
outs[0] = a / c_conj;
outs[1] = -a * out_div_c_conj;
return outs;
}
};

// Gradient of x for real types: dx = dout / y
template <typename T>
struct DivGradXFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a / b; }
};

// Gradient of x for complex types: dx = dout / conj(y)
template <typename T>
struct DivGradXFunctor<Complex<T>> {
inline HOSTDEVICE Complex<T> operator()(const Complex<T>& a,
const Complex<T>& b) const {
Complex<T> b_conj(b.real, -b.imag);
return a / b_conj;
}
};

// Gradient of y for real types: dy = -dout * out / y
template <typename T>
struct DivGradYFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b, const T& c) const {
return -a * b / c;
}
};

// Gradient of y for complex types: dy = -dout * conj(out / y)
template <typename T>
struct DivGradYFunctor<Complex<T>> {
inline HOSTDEVICE Complex<T> operator()(const Complex<T>& a,
const Complex<T>& b,
const Complex<T>& c) const {
Complex<T> out_div_c_conj((b / c).real, -(b / c).imag);
return -a * out_div_c_conj;
}
};
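
A standalone sanity check of the two complex formulas, written with std::complex in place of paddle::platform::complex (illustrative only, not part of the PR):

#include <complex>
#include <cstdio>

int main() {
  const std::complex<float> x(1.0f, 2.0f), y(3.0f, -1.0f), dout(0.5f, 0.25f);
  const std::complex<float> out = x / y;
  // DivGradXFunctor<Complex<T>>: dx = dout / conj(y)
  const std::complex<float> dx = dout / std::conj(y);
  // DivGradYFunctor<Complex<T>>: dy = -dout * conj(out / y)
  const std::complex<float> dy = -dout * std::conj(out / y);
  std::printf("dx = %f %+fi\n", dx.real(), dx.imag());
  std::printf("dy = %f %+fi\n", dy.real(), dy.imag());
  return 0;
}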

// Fmax
template <typename T>
struct FMaxFunctor {
73 changes: 73 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_op_function.h
@@ -43,6 +43,7 @@ limitations under the License. */
#include <thrust/iterator/iterator_adaptor.h>

#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"

@@ -2619,5 +2620,77 @@ static inline std::vector<int> GetReduceDim(const framework::DDim &in,
}
return dims;
}

#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T>
void ReduceWrapper(const platform::CUDADeviceContext &dev_ctx, int axis,
framework::Tensor *src, framework::Tensor *dst) {
std::vector<int> reduce_dims = GetReduceDim(dst->dims(), src->dims(), axis);
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*src, dst, kps::IdentityFunctor<T>(), reduce_dims, dev_ctx.stream());
}
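
Why a reduction recovers the broadcast gradient: if x was broadcast along some dimensions in the forward pass, every broadcast copy contributes to the loss, so the chain rule sums over those dimensions; for x of shape [n] broadcast against out of shape [m, n],

\frac{\partial L}{\partial x_j} = \sum_{i=1}^{m} \frac{\partial L}{\partial \mathrm{out}_{ij}} \cdot \frac{\partial \mathrm{out}_{ij}}{\partial x_j}

GetReduceDim identifies exactly those dimensions, and TensorReduceFunctorImpl carries out the summation (kps::AddFunctor over an identity transform).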

template <ElementwiseType ET, typename T, typename Functor>
void GetGradXAndYOut(const platform::CUDADeviceContext &dev_ctx,
const platform::Place &place, int axis,
std::vector<const framework::Tensor *> ins,
const framework::Tensor *dout, framework::Tensor *dx,
framework::Tensor *dy, Functor func) {
framework::Tensor tmp_dx;
framework::Tensor tmp_dy;
dy->mutable_data<T>(place);
std::vector<framework::Tensor *> outs;
if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
outs = {dx, dy};
} else if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {
tmp_dx.mutable_data<T>(dout->dims(), place);
outs = {&tmp_dx, dy};
} else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) {
tmp_dy.mutable_data<T>(dout->dims(), place);
outs = {dx, &tmp_dy};
} else if (dx->dims() != dout->dims() && dy->dims() != dout->dims()) {
tmp_dy.mutable_data<T>(dout->dims(), place);
tmp_dx.mutable_data<T>(dout->dims(), place);
outs = {&tmp_dx, &tmp_dy};
}

LaunchElementwiseCudaKernel<ET, T, T, decltype(func), 2>(dev_ctx, ins, &outs,
axis, func);

if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {
ReduceWrapper<T>(dev_ctx, axis, &tmp_dx, dx);
} else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) {
ReduceWrapper<T>(dev_ctx, axis, &tmp_dy, dy);
} else if (dx->dims() != dout->dims() && dy->dims() != dout->dims()) {
ReduceWrapper<T>(dev_ctx, axis, &tmp_dx, dx);
ReduceWrapper<T>(dev_ctx, axis, &tmp_dy, dy);
}
}
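
A minimal CPU sketch of the overall "broadcast div backward by reduce" pattern that GetGradXAndYOut implements on the GPU, computing full dout-sized gradients first and then summing over the broadcast dimension (names and shapes here are illustrative only, assuming x of shape [n] broadcast against y/out/dout of shape [m, n]; note that x itself is never read, since out already carries x / y):

#include <cstddef>
#include <vector>

void BroadcastDivGradCPU(const std::vector<float>& y,     // [m * n]
                         const std::vector<float>& out,   // [m * n], out = x / y
                         const std::vector<float>& dout,  // [m * n]
                         std::size_t m, std::size_t n,
                         std::vector<float>* dx,          // [n]
                         std::vector<float>* dy) {        // [m * n]
  dx->assign(n, 0.0f);
  dy->assign(m * n, 0.0f);
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      const std::size_t k = i * n + j;
      // Step 1: full dout-sized gradients, as the fused GPU kernel emits them.
      const float dx_full = dout[k] / y[k];
      (*dy)[k] = -dout[k] * out[k] / y[k];
      // Step 2: sum over the broadcast dimension of x, the job that
      // ReduceWrapper delegates to TensorReduceFunctorImpl on the GPU.
      (*dx)[j] += dx_full;
    }
  }
}

GetGradXOrYOut below follows the same pattern when only one of the two gradients is requested.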

template <ElementwiseType ET, typename T, typename Functor>
void GetGradXOrYOut(const platform::CUDADeviceContext &dev_ctx,
const platform::Place &place, int axis,
std::vector<const framework::Tensor *> ins,
const framework::Tensor *dout, framework::Tensor *dxy,
Functor func) {
framework::Tensor tmp_dxy;
dxy->mutable_data<T>(place);

std::vector<framework::Tensor *> outs;
if (dxy->dims() != dout->dims()) {
tmp_dxy.mutable_data<T>(dout->dims(), place);
outs = {&tmp_dxy};
} else {
outs = {dxy};
}

LaunchElementwiseCudaKernel<ET, T, T>(dev_ctx, ins, &outs, axis, func);
if (dxy->dims() != dout->dims()) {
ReduceWrapper<T>(dev_ctx, axis, &tmp_dxy, dxy);
}
}

#endif
Contributor:

GetGradXOut and GetGradYOut are too similar as two separate functions; they could be collapsed into one.

Contributor Author:

Done.


} // namespace operators
} // namespace paddle