implementation of broadcast div backward by reduce #38044
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"

 namespace ops = paddle::operators;
@@ -23,83 +21,24 @@ namespace plat = paddle::platform;
 namespace paddle {
 namespace operators {

-template <typename T>
-static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y,
-                                                       const T* out,
-                                                       const T* dout,
-                                                       int64_t size, T* dx,
-                                                       T* dy) {
-  int col = blockIdx.x * blockDim.x + threadIdx.x;
-
-  while (col < size) {
-    T o = dout[col];
-    dx[col] = o / y[col];
-    dy[col] = -o * out[col] / y[col];
-    col += blockDim.x * gridDim.x;
-  }
-}
-
-template <>
-__global__ void
-SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<float>>(
-    const paddle::platform::complex<float>* x,
-    const paddle::platform::complex<float>* y,
-    const paddle::platform::complex<float>* out,
-    const paddle::platform::complex<float>* dout, int64_t size,
-    paddle::platform::complex<float>* dx,
-    paddle::platform::complex<float>* dy) {
-  int col = blockIdx.x * blockDim.x + threadIdx.x;
-
-  while (col < size) {
-    paddle::platform::complex<float> o = dout[col];
-    paddle::platform::complex<float> y_conj(y[col].real, -y[col].imag);
-    paddle::platform::complex<float> out_div_y_conj((out[col] / y[col]).real,
-                                                    -(out[col] / y[col]).imag);
-    dx[col] = o / y_conj;
-    dy[col] = -o * out_div_y_conj;
-    col += blockDim.x * gridDim.x;
-  }
-}
-
-template <>
-__global__ void
-SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<double>>(
-    const paddle::platform::complex<double>* x,
-    const paddle::platform::complex<double>* y,
-    const paddle::platform::complex<double>* out,
-    const paddle::platform::complex<double>* dout, int64_t size,
-    paddle::platform::complex<double>* dx,
-    paddle::platform::complex<double>* dy) {
-  int col = blockIdx.x * blockDim.x + threadIdx.x;
-
-  while (col < size) {
-    paddle::platform::complex<double> o = dout[col];
-    paddle::platform::complex<double> y_conj(y[col].real, -y[col].imag);
-    paddle::platform::complex<double> out_div_y_conj((out[col] / y[col]).real,
-                                                     -(out[col] / y[col]).imag);
-    dx[col] = o / y_conj;
-    dy[col] = -o * out_div_y_conj;
-    col += blockDim.x * gridDim.x;
-  }
-}
-
 template <typename DeviceContext, typename T>
 typename std::enable_if<
-    std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
-elementwise_div_grad(const framework::ExecutionContext& ctx,
-                     const framework::Tensor* x, const framework::Tensor* y,
-                     const framework::Tensor* out,
-                     const framework::Tensor* dout, framework::Tensor* dx,
-                     framework::Tensor* dy) {
-  dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1);
-  auto size = x->numel();
-  dim3 grid_size =
-      dim3((size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1);
-  SimpleElemwiseDivGradCUDAKernel<
-      T><<<grid_size, block_size, 0,
-           ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
-      x->data<T>(), y->data<T>(), out->data<T>(), dout->data<T>(), size,
-      dx->mutable_data<T>(ctx.GetPlace()), dy->mutable_data<T>(ctx.GetPlace()));
-}
+    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
+ElementwiseDivGrad(const framework::ExecutionContext& ctx,
+                   const framework::Tensor* x, const framework::Tensor* y,
+                   const framework::Tensor* out, const framework::Tensor* dout,
+                   framework::Tensor* dx, framework::Tensor* dy) {
+  int axis = ctx.Attr<int>("axis");
+  const auto& dev_ctx =
+      ctx.template device_context<platform::CUDADeviceContext>();
+  if (dx != nullptr && dy != nullptr) {
+    GetGradXYOut<T>(dev_ctx, axis, x, y, out, dout, dx, dy,
+                    DivGradXYFunctor<T, T>());
+  } else if (dx != nullptr && dy == nullptr) {
+    GetGradXOut<T>(dev_ctx, axis, x, y, dout, dx, DivGradXFunctor<T>());
+  } else if (dy != nullptr && dx == nullptr) {
+    GetGradYOut<T>(dev_ctx, axis, y, out, dout, dy, DivGradYFunctor<T>());
+  }
+}

 } // namespace operators

Review comment (on the GetGradXYOut call): It feels like these could be wrapped into a single function; pass a null pointer for any unused parameter, and then only the functor differs.
Reply (author): If null pointers are passed in, GetGradXYOut would have to check repeatedly whether each pointer is null, which hurts readability.
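For reference, with out = x / y, both the removed kernels and the new functors implement the quotient-rule backward:

\frac{\partial L}{\partial x} = \frac{\mathrm{dout}}{y}, \qquad
\frac{\partial L}{\partial y} = -\,\mathrm{dout}\cdot\frac{\mathrm{out}}{y} = -\,\mathrm{dout}\cdot\frac{x}{y^{2}}

For complex types the conjugate enters (the y_conj and out_div_y_conj terms in the deleted kernels above):

\frac{\partial L}{\partial x} = \frac{\mathrm{dout}}{\overline{y}}, \qquad
\frac{\partial L}{\partial y} = -\,\mathrm{dout}\cdot\overline{\left(\frac{\mathrm{out}}{y}\right)}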
The second file in the diff:
@@ -43,6 +43,7 @@ limitations under the License. */
 #include <thrust/iterator/iterator_adaptor.h>

+#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
@@ -2619,5 +2620,112 @@ static inline std::vector<int> GetReduceDim(const framework::DDim &in,
   }
   return dims;
 }

+#if defined(__NVCC__) || defined(__HIPCC__)
+template <typename T>
+void ReduceWrapper(const platform::CUDADeviceContext &dev_ctx, int axis,
+                   framework::Tensor *src, framework::Tensor *dst) {
+  std::vector<int> reduce_dims = GetReduceDim(dst->dims(), src->dims(), axis);
+  TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+      *src, dst, kps::IdentityFunctor<T>(), reduce_dims, dev_ctx.stream());
+}
+
+template <typename T, typename Functor>
+void GetGradXYOut(const platform::CUDADeviceContext &dev_ctx, int axis,
+                  const framework::Tensor *x, const framework::Tensor *y,
+                  const framework::Tensor *out, const framework::Tensor *dout,
+                  framework::Tensor *dx, framework::Tensor *dy, Functor func) {
+  framework::Tensor tmp_dx;
+  framework::Tensor tmp_dy;
+  dx->mutable_data<T>(platform::CUDAPlace());
+  dy->mutable_data<T>(platform::CUDAPlace());
+  // For inplace strategy, dx will be stored in addr of dout, which makes
+  // the result of dy wrong.
+  if (dx->IsSharedBufferWith(*dout)) {
+    dx->clear();
+    dx->mutable_data<T>(x->dims(), platform::CUDAPlace());
+  }
+
+  std::vector<const framework::Tensor *> ins = {dout, out, y};
+  std::vector<framework::Tensor *> outs;
+  if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
+    outs = {dx, dy};
+  }
+  if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {

Review comment: This should be …
Reply (author): Done.

+    tmp_dx.mutable_data<T>(dout->dims(), platform::CUDAPlace());
+    outs = {&tmp_dx, dy};
+  } else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) {
+    tmp_dy.mutable_data<T>(dout->dims(), platform::CUDAPlace());
+    outs = {dx, &tmp_dy};
+  } else if (dx->dims() != dout->dims() && dy->dims() != dout->dims()) {
+    tmp_dy.mutable_data<T>(dout->dims(), platform::CUDAPlace());
+    tmp_dx.mutable_data<T>(dout->dims(), platform::CUDAPlace());
+    outs = {&tmp_dx, &tmp_dy};
+  }
+
+  LaunchElementwiseCudaKernel<ElementwiseType::kTernary, T, T, decltype(func),
+                              2>(dev_ctx, ins, &outs, axis, func);
+
+  if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {
+    ReduceWrapper<T>(dev_ctx, axis, &tmp_dx, dx);
+  } else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) {
+    ReduceWrapper<T>(dev_ctx, axis, &tmp_dy, dy);
+  } else if (dx->dims() != dout->dims() && dy->dims() != dout->dims()) {
+    ReduceWrapper<T>(dev_ctx, axis, &tmp_dx, dx);
+    ReduceWrapper<T>(dev_ctx, axis, &tmp_dy, dy);
+  }
+}
+
+template <typename T, typename Functor>
+void GetGradXOut(const platform::CUDADeviceContext &dev_ctx, int axis,
+                 const framework::Tensor *x, const framework::Tensor *y,
+                 const framework::Tensor *dout, framework::Tensor *dx,
+                 Functor func) {
+  framework::Tensor tmp_dx;
+  dx->mutable_data<T>(platform::CUDAPlace());
+  if (dx->IsSharedBufferWith(*dout)) {
+    dx->clear();
+    dx->mutable_data<T>(x->dims(), platform::CUDAPlace());
+  }
+  std::vector<const framework::Tensor *> ins = {dout, y};
+  std::vector<framework::Tensor *> outs;
+  if (dx->dims() != dout->dims()) {
+    tmp_dx.mutable_data<T>(dout->dims(), platform::CUDAPlace());
+    outs = {&tmp_dx};
+  } else {
+    outs = {dx};
+  }
+
+  LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+      dev_ctx, ins, &outs, axis, func);
+  if (dx->dims() != dout->dims()) {
+    ReduceWrapper<T>(dev_ctx, axis, &tmp_dx, dx);
+  }
+}
+
+template <typename T, typename Functor>
+void GetGradYOut(const platform::CUDADeviceContext &dev_ctx, int axis,
+                 const framework::Tensor *y, const framework::Tensor *out,
+                 const framework::Tensor *dout, framework::Tensor *dy,
+                 Functor func) {
+  framework::Tensor tmp_dy;
+  dy->mutable_data<T>(platform::CUDAPlace());
+  std::vector<const framework::Tensor *> ins = {dout, out, y};
+  std::vector<framework::Tensor *> outs;
+  if (dy->dims() != dout->dims()) {
+    tmp_dy.mutable_data<T>(dout->dims(), platform::CUDAPlace());
+    outs = {&tmp_dy};
+  } else {
+    outs = {dy};
+  }
+
+  LaunchElementwiseCudaKernel<ElementwiseType::kTernary, T, T>(
+      dev_ctx, ins, &outs, axis, func);
+  if (dy->dims() != dout->dims()) {
+    ReduceWrapper<T>(dev_ctx, axis, &tmp_dy, dy);
+  }
+}
+#endif

 } // namespace operators
 } // namespace paddle

Review comment: …
Reply (author): Done.
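All three helpers follow the same pattern: when an input was broadcast in the forward pass, its gradient is first computed elementwise at dout's full shape into a temporary tensor and then summed back down to the input's shape by ReduceWrapper. A minimal CPU sketch of that reduce step, using hypothetical shapes and plain C++ in place of the Paddle tensor API:

#include <cstdio>
#include <vector>

// Sum a [rows x cols] gradient down to shape [cols]: the CPU analogue of
// ReduceWrapper reducing tmp_dy (at dout's shape) to dy (at y's shape).
std::vector<float> ReduceRows(const std::vector<float> &full_grad, int rows,
                              int cols) {
  std::vector<float> reduced(cols, 0.0f);
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c) reduced[c] += full_grad[r * cols + c];
  return reduced;
}

int main() {
  // Forward (hypothetical): x is [2, 3]; y is [3] and was broadcast over rows.
  // Backward: dy is first produced at dout's shape [2, 3] ...
  std::vector<float> tmp_dy = {1, 2, 3, 4, 5, 6};
  // ... then reduced over the broadcast dimension to y's shape [3].
  std::vector<float> dy = ReduceRows(tmp_dy, 2, 3);
  std::printf("%g %g %g\n", dy[0], dy[1], dy[2]);  // prints: 5 7 9
  return 0;
}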
Review comment (author): The header file has already been deleted.
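For context on the earlier review thread about folding the three helpers into one: the suggestion was a single entry point taking nullable dx/dy, at the cost of per-output null checks. A hypothetical plain-C++ stand-in (not the code this PR adopted) that illustrates the trade-off:

#include <cstdio>

// Hypothetical merged helper in the spirit of the review suggestion: callers
// pass nullptr for any gradient they do not need, so each output is guarded
// by a null check, which is the readability cost the author's reply cites.
void DivGrad(const float *y, const float *out, const float *dout, int n,
             float *dx, float *dy) {
  for (int i = 0; i < n; ++i) {
    if (dx != nullptr) dx[i] = dout[i] / y[i];            // dx = dout / y
    if (dy != nullptr) dy[i] = -dout[i] * out[i] / y[i];  // dy = -dout*out/y
  }
}

int main() {
  const float y[] = {2, 4}, out[] = {3, 2}, dout[] = {1, 1};
  float dx[2];
  DivGrad(y, out, dout, 2, dx, nullptr);  // dy not requested
  std::printf("%g %g\n", dx[0], dx[1]);   // prints: 0.5 0.25
  return 0;
}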