[ROCM] fix softmax_with_cross_entropy_op (PaddlePaddle#31982)
ronny1996 committed Apr 2, 2021
1 parent 3560e68 commit e12798d
Showing 4 changed files with 20 additions and 7 deletions.
11 changes: 8 additions & 3 deletions paddle/fluid/operators/math/cross_entropy.cu
@@ -66,18 +66,23 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
 
     int batch_size = prob->dims()[0];
     int class_num = prob->dims()[1];
+#ifdef __HIPCC__
+    constexpr int kMaxBlockDim = 256;
+#else
+    constexpr int kMaxBlockDim = 512;
+#endif
 
     if (softLabel) {
      const T* label_data = labels->data<T>();
-      int block = class_num > 512
-                      ? 512
+      int block = class_num > kMaxBlockDim
+                      ? kMaxBlockDim
                       : pow(2, static_cast<int>(std::log2(class_num)));
 
       SoftCrossEntropyKernel<T><<<batch_size, block, 0, ctx.stream()>>>(
           loss_data, prob_data, label_data, class_num);
     } else {
       const int64_t* label_data = labels->data<int64_t>();
-      int block = 512;
+      int block = kMaxBlockDim;
       int grid = (batch_size + block - 1) / block;
       CrossEntropyKernel<T><<<grid, block, 0, ctx.stream()>>>(
           loss_data, prob_data, label_data, batch_size, class_num,
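
The pattern introduced in this file — cap the launch block size per platform, then round the class count down to a power of two — can be written as a small standalone helper. The sketch below is illustrative only (the helper name and the sample values are not from the commit); on a ROCm build the cap drops to 256 threads per block, while CUDA keeps 512.

#include <cmath>
#include <cstdio>

// Illustrative helper (not part of the commit): pick a GPU block size as the
// largest power of two <= n, capped at the platform's kMaxBlockDim.
static int PickBlockDim(int n) {
#ifdef __HIPCC__
  constexpr int kMaxBlockDim = 256;  // ROCm/HIP build
#else
  constexpr int kMaxBlockDim = 512;  // CUDA build
#endif
  return n > kMaxBlockDim
             ? kMaxBlockDim
             : static_cast<int>(std::pow(2, static_cast<int>(std::log2(n))));
}

int main() {
  // e.g. 10 -> 8, 300 -> 256, 1000 -> 512 on CUDA (256 on ROCm)
  std::printf("%d %d %d\n", PickBlockDim(10), PickBlockDim(300),
              PickBlockDim(1000));
  return 0;
}
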
10 changes: 6 additions & 4 deletions paddle/fluid/operators/math/softmax.cu
@@ -54,10 +54,11 @@ void SoftmaxCUDNNFunctor<T>::operator()(
       xDesc.descriptor<T>(layout, cudnn_tensor_dims);
   miopenTensorDescriptor_t cudnn_y_desc =
       xDesc.descriptor<T>(layout, cudnn_tensor_dims);
-  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward(
+  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward_V2(
       context.cudnn_handle(), CudnnDataType<T>::kOne(), cudnn_x_desc,
       X->data<T>(), CudnnDataType<T>::kZero(), cudnn_y_desc,
-      Y->mutable_data<T>(context.GetPlace())));
+      Y->mutable_data<T>(context.GetPlace()), MIOPEN_SOFTMAX_ACCURATE,
+      MIOPEN_SOFTMAX_MODE_INSTANCE));
 #else
   cudnnTensorDescriptor_t cudnn_x_desc =
       xDesc.descriptor<T>(layout, cudnn_tensor_dims);
@@ -96,11 +97,12 @@ void SoftmaxGradCUDNNFunctor<T>::operator()(
       dxDesc.descriptor<T>(layout, cudnn_tensor_dims);
   miopenTensorDescriptor_t cudnn_ygrad_desc =
       dyDesc.descriptor<T>(layout, cudnn_tensor_dims);
-  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxBackward(
+  PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxBackward_V2(
       context.cudnn_handle(), CudnnDataType<T>::kOne(), cudnn_y_desc,
       Y->data<T>(), cudnn_ygrad_desc, YGrad->data<T>(),
       CudnnDataType<T>::kZero(), cudnn_xgrad_desc,
-      XGrad->mutable_data<T>(context.GetPlace())));
+      XGrad->mutable_data<T>(context.GetPlace()), MIOPEN_SOFTMAX_ACCURATE,
+      MIOPEN_SOFTMAX_MODE_INSTANCE));
 #else
   cudnnTensorDescriptor_t cudnn_y_desc =
       yDesc.descriptor<T>(layout, cudnn_tensor_dims);
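
For reference, the `_V2` MIOpen entry points keep the alpha/beta and tensor-descriptor arguments of the original functions and append two trailing parameters: the softmax algorithm and the softmax mode. Below is a minimal sketch of a raw forward call with the same choices the commit makes (MIOPEN_SOFTMAX_ACCURATE, MIOPEN_SOFTMAX_MODE_INSTANCE); the wrapper name, descriptor, and buffers are assumptions for illustration, not code from the repository.

#include <miopen/miopen.h>

// Illustrative sketch: softmax over each instance (row) of a tensor described
// by `desc`, using the numerically stable "accurate" algorithm.
miopenStatus_t RunSoftmaxForward(miopenHandle_t handle,
                                 miopenTensorDescriptor_t desc,
                                 const float* d_x, float* d_y) {
  const float alpha = 1.0f, beta = 0.0f;
  return miopenSoftmaxForward_V2(handle, &alpha, desc, d_x, &beta, desc, d_y,
                                 MIOPEN_SOFTMAX_ACCURATE,
                                 MIOPEN_SOFTMAX_MODE_INSTANCE);
}
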
4 changes: 4 additions & 0 deletions paddle/fluid/operators/softmax_with_cross_entropy_op.cu
@@ -672,7 +672,11 @@ template <typename T>
 static void SoftmaxWithCrossEntropyFusedKernel(
     const T* logits_data, const T* labels_data, T* softmax_data, T* loss_data,
     int64_t n, int64_t d, int axis_dim, gpuStream_t stream) {
+#ifdef __HIPCC__
+  constexpr int kMaxBlockDim = 256;
+#else
   constexpr int kMaxBlockDim = 512;
+#endif
   int64_t block_dim = axis_dim >= kMaxBlockDim
                           ? kMaxBlockDim
                           : (1 << static_cast<int>(std::log2(axis_dim)));
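
The fused kernel reuses the same capping idea but rounds down with a bit shift rather than std::pow. A brief illustrative check of the resulting launch sizes under the ROCm cap (helper name and values are made up, not from the commit):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative only: mirrors the fused kernel's block_dim selection.
static int64_t FusedBlockDim(int64_t axis_dim, int kMaxBlockDim) {
  return axis_dim >= kMaxBlockDim
             ? kMaxBlockDim
             : (1 << static_cast<int>(std::log2(axis_dim)));
}

int main() {
  // With the ROCm cap of 256: axis_dim 1000 -> 256, 100 -> 64, 256 -> 256.
  std::printf("%lld %lld %lld\n",
              static_cast<long long>(FusedBlockDim(1000, 256)),
              static_cast<long long>(FusedBlockDim(100, 256)),
              static_cast<long long>(FusedBlockDim(256, 256)));
  return 0;
}
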
2 changes: 2 additions & 0 deletions paddle/fluid/platform/dynload/miopen.h
@@ -116,7 +116,9 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
   __macro(miopenPoolingForward);               \
   __macro(miopenPoolingBackward);              \
   __macro(miopenSoftmaxBackward);              \
+  __macro(miopenSoftmaxBackward_V2);           \
   __macro(miopenSoftmaxForward);               \
+  __macro(miopenSoftmaxForward_V2);            \
   __macro(miopenCreateDropoutDescriptor);      \
   __macro(miopenDestroyDropoutDescriptor);     \
   __macro(miopenRestoreDropoutDescriptor);     \
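
The `__macro(...)` list feeds a wrapper-generating macro, so each listed MIOpen symbol is resolved from the shared library at runtime rather than linked directly; adding the two `_V2` names is what makes `platform::dynload::miopenSoftmaxForward_V2` and `miopenSoftmaxBackward_V2` available to the softmax functors above. Below is a simplified sketch of how such a dlopen/dlsym wrapper typically works; this is the general pattern, not Paddle's actual macro, and the macro name is invented for illustration.

#include <dlfcn.h>

// Simplified sketch (not Paddle's real macro): each wrapped symbol gets a
// functor that resolves it from libMIOpen.so once and forwards all arguments.
#define DECLARE_DYNLOAD_WRAP(__name)                                  \
  struct DynLoad__##__name {                                          \
    template <typename... Args>                                       \
    auto operator()(Args... args) {                                   \
      static void* lib = dlopen("libMIOpen.so", RTLD_LAZY);           \
      static auto fn =                                                 \
          reinterpret_cast<decltype(&::__name)>(dlsym(lib, #__name)); \
      return fn(args...);                                              \
    }                                                                  \
  }

// Adding the two new names to the list generates wrappers such as:
// DECLARE_DYNLOAD_WRAP(miopenSoftmaxForward_V2);
// DECLARE_DYNLOAD_WRAP(miopenSoftmaxBackward_V2);
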
