Added exp FP32 FWD/BWD oneDNN kernel and optimized other oneDNN grad kernels #38624

Merged on Jan 6, 2022 (4 commits)
95 changes: 70 additions & 25 deletions paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc
@@ -83,9 +83,9 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
const auto &mkldnn_engine = dev_ctx.GetEngine();

const auto *x = ctx.Input<Tensor>("X");
- auto *y = ctx.Output<Tensor>("Out");
+ auto *out = ctx.Output<Tensor>("Out");

- bool is_inplaced = x->IsSharedBufferWith(*y);
+ bool is_inplaced = x->IsSharedBufferWith(*out);

platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
ctx.GetPlace(), x);
@@ -94,9 +94,9 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
std::shared_ptr<dnnl::memory> dst_memory_p = nullptr;
if (is_inplaced) {
dst_memory_p = src_memory_p;
- y->mutable_data<T>(ctx.GetPlace());
+ out->mutable_data<T>(ctx.GetPlace());
} else {
- dst_memory_p = handler.AcquireDstMemory(y);
+ dst_memory_p = handler.AcquireDstMemory(out);
}
auto activation_p = handler.AcquireForwardPrimitive();

@@ -105,8 +105,8 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
astream, {{DNNL_ARG_FROM, *src_memory_p}, {DNNL_ARG_TO, *dst_memory_p}});
astream.wait();

- y->set_layout(DataLayout::kMKLDNN);
- y->set_format(GetMKLDNNFormat(*dst_memory_p));
+ out->set_layout(DataLayout::kMKLDNN);
+ out->set_format(GetMKLDNNFormat(*dst_memory_p));
}

template <typename T>
@@ -116,15 +116,15 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
const auto &mkldnn_engine = dev_ctx.GetEngine();

const auto *x = ctx.Input<Tensor>("X");
- const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
- auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
+ const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+ auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));

platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
- ctx.GetPlace(), x, diff_y);
+ ctx.GetPlace(), x, dout);

auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
- auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
- auto diff_src_memory_p = handler.AcquireDiffSrcMemory(diff_x);
+ auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
+ auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
auto activation_backward_p = handler.AcquireBackwardPrimitive();

auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
@@ -134,8 +134,37 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
{DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait();

- diff_x->set_layout(DataLayout::kMKLDNN);
- diff_x->set_format(GetMKLDNNFormat(*diff_src_memory_p));
+ dx->set_layout(DataLayout::kMKLDNN);
+ dx->set_format(GetMKLDNNFormat(*diff_src_memory_p));
}

template <typename T>
void eltwise_grad_use_out(const framework::ExecutionContext &ctx,
dnnl::algorithm algorithm) {
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine();

const auto *out = ctx.Input<Tensor>("Out");
const auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));

platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine,
ctx.GetPlace(), out, dout);

auto dst_memory_p = handler.AcquireBackwardSrcMemory(out);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx);
auto activation_backward_p = handler.AcquireBackwardPrimitive();

auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
activation_backward_p->execute(astream,
{{DNNL_ARG_DST, *dst_memory_p},
{DNNL_ARG_DIFF_DST, *diff_dst_memory_p},
{DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait();

dx->set_layout(DataLayout::kMKLDNN);
dx->set_format(GetMKLDNNFormat(*diff_src_memory_p));
}

template <typename T, dnnl::algorithm algorithm>
@@ -152,6 +181,13 @@ struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
}
};

template <typename T, dnnl::algorithm algorithm>
struct MKLDNNActivationGradUseOutFunc : public BaseActivationFunctor<T> {
void operator()(const framework::ExecutionContext &ctx) const {
eltwise_grad_use_out<T>(ctx, algorithm);
}
};

template <typename T>
struct GeluMKLDNNFunctor : public BaseActivationFunctor<T> {
void operator()(const framework::ExecutionContext &ctx) const {
@@ -217,6 +253,9 @@ using AbsMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_abs>;
template <typename T>
using EluMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_elu>;

template <typename T>
using ExpMKLDNNFunctor = MKLDNNActivationFunc<T, dnnl::algorithm::eltwise_exp>;

template <typename T>
using ReluMKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_relu>;
@@ -234,24 +273,29 @@ using HardSwishMKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_hardswish>;

template <typename T>
- using SigmoidMKLDNNGradFunctor =
- MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_logistic>;
+ using SigmoidMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
+ T, dnnl::algorithm::eltwise_logistic_use_dst_for_bwd>;

template <typename T>
- using TanhMKLDNNGradFunctor =
- MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_tanh>;
+ using TanhMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
+ T, dnnl::algorithm::eltwise_tanh_use_dst_for_bwd>;

template <typename T>
- using SqrtMKLDNNGradFunctor =
- MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_sqrt>;
+ using SqrtMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
+ T, dnnl::algorithm::eltwise_sqrt_use_dst_for_bwd>;

template <typename T>
using AbsMKLDNNGradFunctor =
MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_abs>;

template <typename T>
- using EluMKLDNNGradFunctor =
- MKLDNNActivationGradFunc<T, dnnl::algorithm::eltwise_elu>;
+ using EluMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
+ T, dnnl::algorithm::eltwise_elu_use_dst_for_bwd>;

template <typename T>
using ExpMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
T, dnnl::algorithm::eltwise_exp_use_dst_for_bwd>;

} // namespace operators
} // namespace paddle

@@ -281,19 +325,20 @@ namespace ops = paddle::operators;
__macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
__macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor); \
__macro(hard_swish, HardSwishMKLDNNFunctor, HardSwishMKLDNNGradFunctor); \
- __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor); \
+ __macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradUseOutFunctor); \
__macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor); \
- __macro(elu, EluMKLDNNFunctor, EluMKLDNNGradFunctor);
+ __macro(elu, EluMKLDNNFunctor, EluMKLDNNGradUseOutFunctor); \
+ __macro(exp, ExpMKLDNNFunctor, ExpMKLDNNGradUseOutFunctor);

FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(relu, ReluMKLDNNFunctor,
ReluMKLDNNGradFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(gelu, GeluMKLDNNFunctor,
GeluMKLDNNGradFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sigmoid, SigmoidMKLDNNFunctor,
- SigmoidMKLDNNGradFunctor);
+ SigmoidMKLDNNGradUseOutFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sqrt, SqrtMKLDNNFunctor,
- SqrtMKLDNNGradFunctor);
+ SqrtMKLDNNGradUseOutFunctor);

namespace ops = paddle::operators;
REGISTER_OP_KERNEL(
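The switch from the plain eltwise algorithms to the *_use_dst_for_bwd variants above works because each of these activations has a derivative that can be written in terms of the forward output alone, so the backward kernel can consume Out and dOut instead of keeping X alive. A minimal NumPy sketch (illustration only, not part of the patch) of the identities involved, using exp as the example:

```python
import numpy as np

# Gradient identities written in terms of the forward output "out" only,
# which is what the eltwise_*_use_dst_for_bwd algorithms rely on:
#   exp:     dx = dout * out              (d/dx exp(x) = exp(x) = out)
#   sigmoid: dx = dout * out * (1 - out)
#   tanh:    dx = dout * (1 - out * out)
#   sqrt:    dx = dout / (2 * out)
x = np.random.rand(5, 5, 4).astype("float32") + 0.5
dout = np.ones_like(x)

out = np.exp(x)
dx_from_out = dout * out  # backward pass needs only Out and dOut

# Cross-check against a central finite difference of exp at x
eps = 1e-3
dx_reference = (np.exp(x + eps) - np.exp(x - eps)) / (2 * eps)
assert np.allclose(dx_from_out, dx_reference, rtol=1e-2)
```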
@@ -349,6 +349,16 @@ def set_alpha(self):
self.alpha = 2.5


class TestMKLDNNExpOp(TestActivation):
def setUp(self):
self.op_type = "exp"
x = np.random.random((5, 5, 4)).astype("float32")

self.inputs = {'X': x}
self.attrs = {'use_mkldnn': True}
self.outputs = {'Out': np.exp(x)}


# Check if primitives already exist in backward
class TestMKLDNNAbsPrimitivesAlreadyExist(unittest.TestCase):
def setUp(self):
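To exercise the new backward kernel explicitly, the test could also carry a gradient check. The sketch below is a hypothetical extension, not part of this PR; it assumes OpTest is importable from this path and that its check_output/check_grad helpers behave as in Paddle's other activation unit tests.

```python
import unittest

import numpy as np

from paddle.fluid.tests.unittests.op_test import OpTest


class TestMKLDNNExpGradCheck(OpTest):
    """Hypothetical extension: forward and gradient check for the oneDNN exp kernel."""

    def setUp(self):
        self.op_type = "exp"
        x = np.random.random((5, 5, 4)).astype("float32")
        self.inputs = {'X': x}
        self.attrs = {'use_mkldnn': True}
        self.outputs = {'Out': np.exp(x)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # dOut/dX equals exp(x) == Out, which is exactly what the
        # eltwise_exp_use_dst_for_bwd path computes from the output.
        self.check_grad(['X'], 'Out')


if __name__ == '__main__':
    unittest.main()
```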