From 8305ba378708cdde9529520f7579a1d8823f4e10 Mon Sep 17 00:00:00 2001 From: TTerror Date: Fri, 3 Sep 2021 14:38:16 +0800 Subject: [PATCH] fix bn_infer and optimize momentum for kunlun (#35250) --- cmake/external/xpu.cmake | 2 +- paddle/fluid/operators/batch_norm_op_xpu.cc | 31 +++++++++---------- .../operators/optimizers/momentum_op_xpu.cc | 8 ++--- paddle/fluid/platform/xpu/xpu2_op_list.h | 4 +++ 4 files changed, 24 insertions(+), 21 deletions(-) diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index 2fc5bf7954ce0..cb946fb85c09a 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -35,7 +35,7 @@ ELSE () ENDIF() SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") -SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210826") +SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210830") SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) diff --git a/paddle/fluid/operators/batch_norm_op_xpu.cc b/paddle/fluid/operators/batch_norm_op_xpu.cc index 526fc7364cdd8..8499d1cdcd646 100644 --- a/paddle/fluid/operators/batch_norm_op_xpu.cc +++ b/paddle/fluid/operators/batch_norm_op_xpu.cc @@ -76,26 +76,25 @@ class BatchNormXPUKernel : public framework::OpKernel { W, epsilon, momentum, scale_data, bias_data, saved_mean_data, saved_variance_data, mean_out_data, variance_out_data, true); - PADDLE_ENFORCE_EQ( - r, XPU_SUCCESS, - platform::errors::External("XPU API(batch_norm_train_forward) return " - "wrong value[%d], please check whether " - "Baidu Kunlun Card is properly installed.", - r)); + PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The batch_norm XPU API return wrong value[%d %s]", + r, XPUAPIErrorMsg[r])); } else { const auto* mean = ctx.Input("Mean"); const auto* variance = ctx.Input("Variance"); - const auto* mean_data = mean->data(); - const auto* variance_data = variance->data(); - int r = xpu::batch_norm_infer_forward( - dev_ctx.x_context(), epsilon, N, C, H, W, x_data, y_data, scale_data, - bias_data, mean_data, variance_data); + const auto* mean_data = mean->data(); + const auto* variance_data = variance->data(); + const auto* x_data = x->data(); + auto* y_data = y->mutable_data(ctx.GetPlace()); + int r = xpu::batch_norm_infer(dev_ctx.x_context(), x_data, y_data, N, C, + H, W, epsilon, scale_data, bias_data, + mean_data, variance_data, true); PADDLE_ENFORCE_EQ( - r, XPU_SUCCESS, - platform::errors::External("XPU API(batch_norm_infer_forward) return " - "wrong value[%d], please check whether " - "Baidu Kunlun Card is properly installed.", - r)); + r, xpu::Error_t::SUCCESS, + platform::errors::External( + "The batch_norm_infer XPU API return wrong value[%d %s]", r, + XPUAPIErrorMsg[r])); } } }; diff --git a/paddle/fluid/operators/optimizers/momentum_op_xpu.cc b/paddle/fluid/operators/optimizers/momentum_op_xpu.cc index 932368e810edd..5624312d9a728 100644 --- a/paddle/fluid/operators/optimizers/momentum_op_xpu.cc +++ b/paddle/fluid/operators/optimizers/momentum_op_xpu.cc @@ -44,10 +44,10 @@ class MomentumOpXPUKernel : public framework::OpKernel { auto grad = ctx.Input("Grad"); auto& dev_ctx = ctx.template device_context(); - int r = xpu::momentum( - dev_ctx.x_context(), param->data(), velocity->data(), - grad->data(), lr, use_nesterov, mu, param_out->numel(), - param_out->data(), velocity_out->data()); + int r = xpu::momentum(dev_ctx.x_context(), param->data(), + velocity->data(), grad->data(), + param_out->data(), velocity_out->data(), + param_out->numel(), lr, use_nesterov, mu); if (r == xpu::Error_t::INVALID_PARAM) { PADDLE_ENFORCE_EQ( r, xpu::Error_t::SUCCESS, diff --git a/paddle/fluid/platform/xpu/xpu2_op_list.h b/paddle/fluid/platform/xpu/xpu2_op_list.h index ab2db1ff3831b..0989f2156877f 100644 --- a/paddle/fluid/platform/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/xpu/xpu2_op_list.h @@ -75,6 +75,10 @@ XPUOpMap& get_kl2_ops() { {"elementwise_min_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, + {"momentum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"batch_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"batch_norm_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, // AddMore };