Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[aarch64] patch mkldnn acl inner product to accelerate torch.compile() for bert #1631

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions aarch64_linux/aarch64_wheel_ci_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def parse_arguments():
# work around to fix Raspberry pie crash
print("Applying mkl-dnn patch to fix readdir crash")
os.system("cd /pytorch/third_party/ideep/mkl-dnn && patch -p1 < /builder/mkldnn_fix/aarch64-fix-readdir-crash.patch")
# patch acl inner product to accelerate torch.compile() path
print("Applying mkl-dnn patch to acl inner product")
os.system("cd /pytorch/third_party/ideep/mkl-dnn && patch -p1 < /builder/mkldnn_fix/cpu-aarch64-add-sbgemm-fp32-input-and-bf16-weights-ip.patch")
os.system(f"cd /pytorch; {build_vars} python3 setup.py bdist_wheel")
pytorch_wheel_name = complete_wheel("pytorch")
print(f"Build Compelete. Created {pytorch_wheel_name}..")
1 change: 1 addition & 0 deletions aarch64_linux/build_aarch64_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,7 @@ def start_build(host: RemoteHost, *,
print("build pytorch with mkldnn+acl backend")
build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
host.run_cmd(f"cd $HOME && git clone https://github.com/pytorch/builder.git")
host.run_cmd(f"cd $HOME/pytorch/third_party/ideep/mkl-dnn && patch -p1 < $HOME/builder/mkldnn_fix/cpu-aarch64-add-sbgemm-fp32-input-and-bf16-weights-ip.patch")
host.run_cmd(f"cd $HOME/pytorch && export ACL_ROOT_DIR=$HOME/ComputeLibrary && {build_vars} python3 setup.py bdist_wheel{build_opts}")
print('Repair the wheel')
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
cpu: aarch64: add sbgemm (fp32 input and bf16 weights) inner
product op

---
src/cpu/aarch64/acl_inner_product.hpp | 8 ++++++--
src/cpu/cpu_inner_product_list.cpp | 4 ++++
2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/src/cpu/aarch64/acl_inner_product.hpp b/src/cpu/aarch64/acl_inner_product.hpp
index a2be164f0..eca56b289 100644
--- a/src/cpu/aarch64/acl_inner_product.hpp
+++ b/src/cpu/aarch64/acl_inner_product.hpp
@@ -99,9 +99,13 @@ struct acl_inner_product_fwd_t : public primitive_t {
const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef)
&& attr()->has_default_values(
primitive_attr_t::skip_mask_t::post_ops, f32);
+ const bool is_fp32_bf16_ok
+ = expect_data_types(f32, bf16, f32, f32, undef)
+ && attr()->has_default_values(
+ primitive_attr_t::skip_mask_t::post_ops, f32);
const bool ok = is_fwd() && !has_zero_dim_memory()
- && utils::one_of(true, is_fp16_ok, is_fp32_ok)
- && weights_md_.format_kind == format_kind::any
+ && utils::one_of(
+ true, is_fp16_ok, is_fp32_ok, is_fp32_bf16_ok)
&& set_default_params() == status::success;

if (!ok) return status::unimplemented;
diff --git a/src/cpu/cpu_inner_product_list.cpp b/src/cpu/cpu_inner_product_list.cpp
index fdd7b1776..5a3dc1ea7 100644
--- a/src/cpu/cpu_inner_product_list.cpp
+++ b/src/cpu/cpu_inner_product_list.cpp
@@ -83,6 +83,10 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
CPU_INSTANCE(ref_inner_product_fwd_t)
nullptr,
}},
+ {{forward, f32, bf16, f32}, {
+ CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t)
+ nullptr,
+ }},
{{backward_data, f32, f32, f32}, REG_BWD_PK({
CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t<avx512_core_amx>) // bf32
CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t<avx512_core>)
--
2.34.1