Skip to content

Commit

Permalink
aarch64: apply the cherrypicked onednn PR-1768 (#1717)
Browse files Browse the repository at this point in the history
This is to improve the torch.compile() perf by 5.8x
on AWS Graviton3 instances. This patching is required
till PyTorch oneDNN is upgraded to v3.4.
  • Loading branch information
snadampal authored Mar 12, 2024
1 parent c084122 commit ab5fc90
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
3 changes: 3 additions & 0 deletions aarch64_linux/aarch64_wheel_ci_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ def parse_arguments():
with open("/builder/mkldnn_fix/fix-xbyak-failure.patch") as f:
check_call(["patch", "-p1"], stdin=f, cwd="/pytorch/third_party/ideep/mkl-dnn")

print("Applying mkl-dnn patch to improve torch.compile() perf")
os.system("cd /pytorch/third_party/ideep/mkl-dnn && patch -p1 < /builder/mkldnn_fix/onednn-pr1768-aarch64-add-acl-sbgemm-inner-product-primitive.patch") # noqa: E501

os.system(f"cd /pytorch; {build_vars} python3 setup.py bdist_wheel")
pytorch_wheel_name = complete_wheel("pytorch")
print(f"Build Compelete. Created {pytorch_wheel_name}..")
1 change: 1 addition & 0 deletions aarch64_linux/build_aarch64_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,7 @@ def start_build(host: RemoteHost, *,
build_vars += " USE_MKLDNN=ON USE_MKLDNN_ACL=ON"
host.run_cmd("cd $HOME && git clone https://github.com/pytorch/builder.git")
host.run_cmd("cd $HOME/pytorch/third_party/ideep/mkl-dnn && patch -p1 < $HOME/builder/mkldnn_fix/fix-xbyak-failure.patch") # noqa: E501
host.run_cmd("cd $HOME/pytorch/third_party/ideep/mkl-dnn && patch -p1 < $HOME/builder/mkldnn_fix/onednn-pr1768-aarch64-add-acl-sbgemm-inner-product-primitive.patch") # noqa: E501
host.run_cmd(f"cd $HOME/pytorch && export ACL_ROOT_DIR=$HOME/ComputeLibrary && {build_vars} python3 setup.py bdist_wheel{build_opts}") # noqa: E501
print('Repair the wheel')
pytorch_wheel_name = host.list_dir("pytorch/dist")[0]
Expand Down

0 comments on commit ab5fc90

Please sign in to comment.