Skip to content

Commit

Permalink
fix bn_infer and optimize momentum for kunlun (#35250)
Browse files Browse the repository at this point in the history
  • Loading branch information
tangzhiyi11 authored Sep 3, 2021
1 parent 8ba58eb commit 8305ba3
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 21 deletions.
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ ELSE ()
ENDIF()

SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210826")
SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210830")
SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE)
Expand Down
31 changes: 15 additions & 16 deletions paddle/fluid/operators/batch_norm_op_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,26 +76,25 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
W, epsilon, momentum, scale_data, bias_data,
saved_mean_data, saved_variance_data,
mean_out_data, variance_out_data, true);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(batch_norm_train_forward) return "
"wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
r));
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External(
"The batch_norm XPU API return wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
} else {
const auto* mean = ctx.Input<Tensor>("Mean");
const auto* variance = ctx.Input<Tensor>("Variance");
const auto* mean_data = mean->data<T>();
const auto* variance_data = variance->data<T>();
int r = xpu::batch_norm_infer_forward(
dev_ctx.x_context(), epsilon, N, C, H, W, x_data, y_data, scale_data,
bias_data, mean_data, variance_data);
const auto* mean_data = mean->data<float>();
const auto* variance_data = variance->data<float>();
const auto* x_data = x->data<float>();
auto* y_data = y->mutable_data<float>(ctx.GetPlace());
int r = xpu::batch_norm_infer(dev_ctx.x_context(), x_data, y_data, N, C,
H, W, epsilon, scale_data, bias_data,
mean_data, variance_data, true);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(batch_norm_infer_forward) return "
"wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
r));
r, xpu::Error_t::SUCCESS,
platform::errors::External(
"The batch_norm_infer XPU API return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
}
}
};
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/operators/optimizers/momentum_op_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ class MomentumOpXPUKernel : public framework::OpKernel<T> {
auto grad = ctx.Input<framework::Tensor>("Grad");

auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::momentum(
dev_ctx.x_context(), param->data<float>(), velocity->data<float>(),
grad->data<float>(), lr, use_nesterov, mu, param_out->numel(),
param_out->data<float>(), velocity_out->data<float>());
int r = xpu::momentum(dev_ctx.x_context(), param->data<float>(),
velocity->data<float>(), grad->data<float>(),
param_out->data<float>(), velocity_out->data<float>(),
param_out->numel(), lr, use_nesterov, mu);
if (r == xpu::Error_t::INVALID_PARAM) {
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/platform/xpu/xpu2_op_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ XPUOpMap& get_kl2_ops() {
{"elementwise_min_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"momentum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"batch_norm", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"batch_norm_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
// AddMore
};

Expand Down

0 comments on commit 8305ba3

Please sign in to comment.