
Reusing of softmax mkldnn primitives #10576

Merged
73 changes: 54 additions & 19 deletions paddle/fluid/operators/softmax_mkldnn_op.cc
@@ -53,25 +53,60 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
"Softmax input and output dimensions should match");
// Same memory descriptor to be used for input and output
memory::dims softmax_tz = {src_tz[0], src_tz[1]};
// Currently only supports NC data format
// TODO(jczaja-intel): support more formats
auto softmax_md =
MKLDNNMemDesc({softmax_tz}, memory::f32, memory::format::nc);
// Normalization is made after innermost dimension eg. C out of NC
auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring,
softmax_md, 1 /*dim: C*/);
// create memory primitives
auto softmax_src_memory =
memory({softmax_md, mkldnn_engine},
static_cast<void*>(const_cast<T*>(input_data)));
auto softmax_dst_memory =
memory({softmax_md, mkldnn_engine},
static_cast<void*>(const_cast<T*>(output_data)));
auto softmax_prim_desc =
softmax_forward::primitive_desc(softmax_desc, mkldnn_engine);
auto softmax = softmax_forward(softmax_prim_desc, softmax_src_memory,
softmax_dst_memory);
std::vector<primitive> pipeline{softmax};
+    // Generate keys for storing/retrieving primitives for this operator
+    // TODO(jczaja): Each MKLDNN operator may have a different hashing function
+    auto gethash = [](memory::dims& operand_dims) {
+      return std::string(std::to_string(operand_dims[0]) + "-" +
+                         std::to_string(operand_dims[1]));
+    };
+    const std::string key = gethash(softmax_tz);
+    const std::string key_softmax_p = key + "@softmax_p";
+    const std::string key_softmax_src_mem_p = key + "@softmax_src_mem_p";
+    const std::string key_softmax_dst_mem_p = key + "@softmax_dst_mem_p";
+
+    std::shared_ptr<void> softmax_p = dev_ctx.GetBlob(key_softmax_p);
+    if (softmax_p == nullptr) {
+      // Currently only NC data format is supported
+      auto softmax_md =
+          MKLDNNMemDesc({softmax_tz}, memory::f32, memory::format::nc);
Contributor:

f32 should depend on T, right? Maybe this should be enhanced, or at least enforced as float.
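One way this could go (a sketch only — the trait below is hypothetical and not part of this PR or of Paddle's platform code) is to derive the MKL-DNN data type from the kernel's template parameter T instead of hard-coding f32:

#include <mkldnn.hpp>

// Hypothetical compile-time mapping from the kernel's template parameter T
// to the matching MKL-DNN data type (only float is wired up in this sketch).
template <typename T>
struct MKLDNNDataType;

template <>
struct MKLDNNDataType<float> {
  static constexpr mkldnn::memory::data_type value =
      mkldnn::memory::data_type::f32;
};

// The descriptor creation in the kernel would then read (sketch):
//   auto softmax_md = MKLDNNMemDesc(
//       {softmax_tz}, MKLDNNDataType<T>::value, memory::format::nc);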

Contributor:

It should be enhanced in the next PR, since I find:

BDSHYF000120887:operators luotao02$ grep "memory::f32" *.cc
activation_mkldnn_op.cc:                     ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
activation_mkldnn_op.cc:                     : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
activation_mkldnn_op.cc:                     ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
activation_mkldnn_op.cc:                     : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
pool_mkldnn_op.cc:    auto src_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
pool_mkldnn_op.cc:    auto dst_md = platform::MKLDNNMemDesc(dst_tz, mkldnn::memory::f32,
pool_mkldnn_op.cc:                  {{}, mkldnn::memory::f32, mkldnn::memory::format::nchw},
pool_mkldnn_op.cc:    auto diff_src_md = platform::MKLDNNMemDesc(diff_src_tz, mkldnn::memory::f32,
pool_mkldnn_op.cc:    auto diff_dst_md = platform::MKLDNNMemDesc(diff_dst_tz, mkldnn::memory::f32,
softmax_mkldnn_op.cc:        MKLDNNMemDesc({softmax_tz}, memory::f32, memory::format::nc);
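Until that enhancement lands, the weaker option from the comment above ("at least enforce as float") could be a compile-time check inside the templated kernel — a sketch under that assumption, not code from this PR:

#include <type_traits>

// Hypothetical helper (not in this PR): reject any instantiation other than
// float until data types other than the hard-coded memory::f32 descriptor
// are actually supported.
template <typename T>
void CheckSoftmaxMKLDNNType() {
  static_assert(std::is_same<T, float>::value,
                "softmax MKL-DNN kernel currently supports only float input");
}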

+      // Normalization is made after innermost dimension eg. C out of NC
+      auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring,
+                                                softmax_md, 1 /*dim: C*/);
+      // create memory primitives
+      auto softmax_src_memory_p = std::make_shared<memory>(
+          memory::primitive_desc{softmax_md, mkldnn_engine},
+          static_cast<void*>(const_cast<T*>(input_data)));
+      dev_ctx.SetBlob(key_softmax_src_mem_p, softmax_src_memory_p);
+      auto softmax_dst_memory_p = std::make_shared<memory>(
+          memory::primitive_desc{softmax_md, mkldnn_engine},
+          static_cast<void*>(output_data));
+      dev_ctx.SetBlob(key_softmax_dst_mem_p, softmax_dst_memory_p);
+
+      auto softmax_forward_pd =
+          std::make_shared<softmax_forward::primitive_desc>(softmax_desc,
+                                                            mkldnn_engine);
+      softmax_p = std::make_shared<softmax_forward>(
+          *(softmax_forward_pd.get()),
+          *(static_cast<memory*>(softmax_src_memory_p.get())),
+          *(static_cast<memory*>(softmax_dst_memory_p.get())));
+      dev_ctx.SetBlob(key_softmax_p, softmax_p);
+    } else {
+      // Primitives already exist
+      auto src_memory_p = std::static_pointer_cast<memory>(
+          dev_ctx.GetBlob(key_softmax_src_mem_p));
+      PADDLE_ENFORCE(src_memory_p != nullptr,
+                     "Fail to find softmax src mem_p in device context");
+      auto dst_memory_p = std::static_pointer_cast<memory>(
+          dev_ctx.GetBlob(key_softmax_dst_mem_p));
+      PADDLE_ENFORCE(dst_memory_p != nullptr,
+                     "Fail to find softmax dst mem_p in device context");
+      src_memory_p->set_data_handle(
+          reinterpret_cast<void*>(const_cast<T*>(input_data)));
+      dst_memory_p->set_data_handle(output_data);
+    }
+
+    std::vector<primitive> pipeline{
+        *(static_cast<softmax_forward::primitive*>(softmax_p.get()))};
     stream(stream::kind::eager).submit(pipeline).wait();

const bool is_test = ctx.Attr<bool>("is_test");
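For readers skimming the diff, the reuse scheme boils down to a shape-keyed blob map on the device context: build the primitive once per input shape, then on later iterations only repoint the cached memory objects at the current buffers via set_data_handle(), skipping primitive reconstruction entirely. A minimal self-contained sketch of that pattern (the BlobMap class below is a hypothetical stand-in for the device context's GetBlob/SetBlob, not Paddle's actual implementation):

#include <map>
#include <memory>
#include <string>

// Hypothetical stand-in for the device context's blob storage used above:
// type-erased shared_ptr values retrieved by string key.
class BlobMap {
 public:
  std::shared_ptr<void> GetBlob(const std::string& key) const {
    auto it = blobs_.find(key);
    return it == blobs_.end() ? nullptr : it->second;
  }
  void SetBlob(const std::string& key, std::shared_ptr<void> blob) {
    blobs_[key] = std::move(blob);
  }

 private:
  std::map<std::string, std::shared_ptr<void>> blobs_;
};

Usage mirrors the kernel: a key such as "2-1000@softmax_p" (gethash joins the two softmax_tz dims with "-", then a per-object suffix is appended) either misses, in which case the primitive and its memories are created and stored, or hits, in which case only the data handles are updated before the pipeline is submitted.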