diff --git a/aarch64_linux/aarch64_wheel_ci_build.py b/aarch64_linux/aarch64_wheel_ci_build.py index 3b772847c..2c37de8ef 100755 --- a/aarch64_linux/aarch64_wheel_ci_build.py +++ b/aarch64_linux/aarch64_wheel_ci_build.py @@ -108,6 +108,9 @@ def parse_arguments(): # work around to fix Raspberry pie crash print("Applying mkl-dnn patch to fix readdir crash") os.system("cd /pytorch/third_party/ideep/mkl-dnn && patch -p1 < /builder/mkldnn_fix/aarch64-fix-readdir-crash.patch") + # patch acl inner product to accelerate torch.compile() path + print("Applying mkl-dnn patch to acl inner product") + os.system("cd /pytorch/third_party/ideep/mkl-dnn && patch -p1 < /builder/mkldnn_fix/cpu-aarch64-add-sbgemm-fp32-input-and-bf16-weights-ip.patch") os.system(f"cd /pytorch; {build_vars} python3 setup.py bdist_wheel") pytorch_wheel_name = complete_wheel("pytorch") print(f"Build Compelete. Created {pytorch_wheel_name}..") diff --git a/aarch64_linux/build_aarch64_wheel.py b/aarch64_linux/build_aarch64_wheel.py index 9efd2e6ae..ba9acedc2 100755 --- a/aarch64_linux/build_aarch64_wheel.py +++ b/aarch64_linux/build_aarch64_wheel.py @@ -555,6 +555,7 @@ def start_build(host: RemoteHost, *, print("build pytorch with mkldnn+acl backend") build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON" host.run_cmd(f"cd $HOME && git clone https://github.com/pytorch/builder.git") + host.run_cmd(f"cd $HOME/pytorch/third_party/ideep/mkl-dnn && patch -p1 < $HOME/builder/mkldnn_fix/cpu-aarch64-add-sbgemm-fp32-input-and-bf16-weights-ip.patch") host.run_cmd(f"cd $HOME/pytorch && export ACL_ROOT_DIR=$HOME/ComputeLibrary && {build_vars} python3 setup.py bdist_wheel{build_opts}") print('Repair the wheel') pytorch_wheel_name = host.list_dir("pytorch/dist")[0] diff --git a/mkldnn_fix/cpu-aarch64-add-sbgemm-fp32-input-and-bf16-weights-ip.patch b/mkldnn_fix/cpu-aarch64-add-sbgemm-fp32-input-and-bf16-weights-ip.patch new file mode 100644 index 000000000..bd147b6ee --- /dev/null +++ b/mkldnn_fix/cpu-aarch64-add-sbgemm-fp32-input-and-bf16-weights-ip.patch @@ -0,0 +1,46 @@ +cpu: aarch64: add sbgemm (fp32 input and bf16 weights) inner + product op + +--- + src/cpu/aarch64/acl_inner_product.hpp | 8 ++++++-- + src/cpu/cpu_inner_product_list.cpp | 4 ++++ + 2 files changed, 10 insertions(+), 2 deletions(-) + +diff --git a/src/cpu/aarch64/acl_inner_product.hpp b/src/cpu/aarch64/acl_inner_product.hpp +index a2be164f0..eca56b289 100644 +--- a/src/cpu/aarch64/acl_inner_product.hpp ++++ b/src/cpu/aarch64/acl_inner_product.hpp +@@ -99,9 +99,13 @@ struct acl_inner_product_fwd_t : public primitive_t { + const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef) + && attr()->has_default_values( + primitive_attr_t::skip_mask_t::post_ops, f32); ++ const bool is_fp32_bf16_ok ++ = expect_data_types(f32, bf16, f32, f32, undef) ++ && attr()->has_default_values( ++ primitive_attr_t::skip_mask_t::post_ops, f32); + const bool ok = is_fwd() && !has_zero_dim_memory() +- && utils::one_of(true, is_fp16_ok, is_fp32_ok) +- && weights_md_.format_kind == format_kind::any ++ && utils::one_of( ++ true, is_fp16_ok, is_fp32_ok, is_fp32_bf16_ok) + && set_default_params() == status::success; + + if (!ok) return status::unimplemented; +diff --git a/src/cpu/cpu_inner_product_list.cpp b/src/cpu/cpu_inner_product_list.cpp +index fdd7b1776..5a3dc1ea7 100644 +--- a/src/cpu/cpu_inner_product_list.cpp ++++ b/src/cpu/cpu_inner_product_list.cpp +@@ -83,6 +83,10 @@ const std::map> &impl_list_map() + CPU_INSTANCE(ref_inner_product_fwd_t) + nullptr, + }}, ++ {{forward, f32, bf16, f32}, { ++ CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t) ++ nullptr, ++ }}, + {{backward_data, f32, f32, f32}, REG_BWD_PK({ + CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t) // bf32 + CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t) +-- +2.34.1 +