Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIRV] Minor update to TIR sort to make it work on VK/SPIR-V #7607

Merged
merged 4 commits into from
Mar 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions python/tvm/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ..transform import strided_slice, transpose
from .. import tag
from ..utils import ceil_div, swap
from ..math import cast


def _schedule_sort(outs):
Expand Down Expand Up @@ -142,6 +143,8 @@ def bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even
"""
# pylint: disable=arguments-out-of-order
# initialize iterators
i = ib.allocate("int64", (1,), name="i", scope="local")
j = ib.allocate("int64", (1,), name="j", scope="local")
i[0] = start
j[0] = middle
# set up indexes
Expand Down Expand Up @@ -189,12 +192,13 @@ def assign_j():

def mergesort(source, dest, source_idx, dest_idx, size, width, even):
# calculate the start, mid, and end points of this section
start[0] = width * tid
with ib.if_scope(start[0] < size):
middle[0] = tvm.te.min(start[0] + tvm.tir.indexdiv(width, 2), size)
end[0] = tvm.te.min(start[0] + width, size)
## merge the start->middle and middle->end arrays
bottom_up_merge(source, dest, source_idx, dest_idx, start[0], middle[0], end[0], even)
start = width * tid

with ib.if_scope(start < size):
middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), "int64")
end = cast(tvm.te.min(start + width, size), "int64")
# merge the start->middle and middle->end arrays
bottom_up_merge(source, dest, source_idx, dest_idx, start, middle, end, even)

lim = tvm.tir.generic.cast(
tvm.tir.ceil(tvm.tir.log2(tvm.tir.generic.cast(size, "float64"))), "int64"
Expand All @@ -203,11 +207,6 @@ def mergesort(source, dest, source_idx, dest_idx, size, width, even):
width = 2 << l2_width
# Define and launch the cuda kernel
with ib.new_scope():
i = ib.allocate("int64", (1,), name="i", scope="local")
j = ib.allocate("int64", (1,), name="j", scope="local")
start = ib.allocate("int64", (1,), name="start", scope="local")
middle = ib.allocate("int64", (1,), name="middle", scope="local")
end = ib.allocate("int64", (1,), name="end", scope="local")
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
Expand Down
6 changes: 3 additions & 3 deletions tests/python/topi/python/test_topi_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def check_device(device):
f(tvm_data, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_sort, rtol=1e0)

for device in ["llvm", "cuda", "opencl"]:
for device in ["llvm", "cuda", "opencl", "vulkan", "nvptx"]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @masahi - in the case of vulkan here the tests won't be run until we add vulkan to our CI docker image, right? These are just tested locally?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes vulkan was tested locally, but not on CI.

check_device(device)


Expand Down Expand Up @@ -115,7 +115,7 @@ def check_device(device):
f(tvm_data, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_indices.astype(data_dtype), rtol=1e0)

for device in ["llvm", "cuda", "opencl"]:
for device in ["llvm", "cuda", "opencl", "vulkan", "nvptx"]:
check_device(device)


Expand Down Expand Up @@ -167,7 +167,7 @@ def check_device(device):
else:
tvm.testing.assert_allclose(tvm_res[0].asnumpy(), np_indices)

for device in ["llvm", "cuda", "opencl"]:
for device in ["llvm", "cuda", "opencl", "vulkan", "nvptx"]:
check_device(device)


Expand Down