Skip to content

Commit

Permalink
use kwarg arch from ipex (#105) (#107)
Browse files Browse the repository at this point in the history
* use cc from ipex

* refine code

(cherry picked from commit c3ecd6d)
  • Loading branch information
guangyey authored Sep 26, 2023
1 parent 7fc9f9b commit 69998dd
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,10 +420,12 @@ def get_kernel_bin(self):
return "spvbin"

def get_architecture_descriptor(self, **kwargs):
dev_props = self.driver.utils.get_device_properties(torch.xpu.device(torch.xpu.current_device()).sycl_device) # noqa: E501
max_work_group_size = dev_props['max_work_group_size']
max_num_sub_groups = dev_props['max_num_sub_groups']
sub_group_sizes = dev_props['sub_group_sizes']
arch = kwargs.get("arch", None)
if arch is None:
arch = self.get_device_properties(self.get_current_device())
max_work_group_size = arch['max_work_group_size']
max_num_sub_groups = arch['max_num_sub_groups']
sub_group_sizes = arch['sub_group_sizes']
# TODO: choose a reasonable subgroup size
threads_per_warp = 32
assert threads_per_warp in sub_group_sizes, "Current platform does not support threads_per_warp to be 32" # noqa: E501
Expand Down

0 comments on commit 69998dd

Please sign in to comment.