Skip to content

Commit

Permalink
used dynamic strided_slice
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Nov 13, 2020
1 parent 14b024b commit e7dc317
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 12 deletions.
67 changes: 56 additions & 11 deletions python/tvm/topi/cuda/argwhere.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .sort import topk, topk_thrust, argsort, argsort_thrust
from .. import tag
from ..transform import strided_slice, adv_index, squeeze
from ..utils import const_vector

logger = logging.getLogger("topi")

Expand All @@ -36,7 +37,7 @@ def _get_sort_func(mode=0):
if get_global_func("tvm.contrib.thrust.sort", allow_missing=True):
ret = topk_thrust if mode == 0 else argsort_thrust
else:
logger.warn(
logger.warning(
"It's highly recommended to enable thrust library with set(USE_THRUST ON)"
" when compiling argwhere for cuda target. Otherwise, it can result in"
" significant performance degradation or incorrect result"
Expand All @@ -46,6 +47,17 @@ def _get_sort_func(mode=0):
return ret


def _create_end(data, out, end):
ib = tvm.tir.ir_builder.create()
end = tvm.tir.const(end, dtype=out.dtype)
out_ptr = ib.buffer_ptr(out)
bx = te.thread_axis("blockIdx.x")
ib.scope_attr(bx, "thread_extent", 1)
out_ptr[0] = data.shape[0]
out_ptr[1] = end
return ib.get()


def argwhere_1d_ir(condition, out):
"""Low level IR for argwhere 1D
Expand Down Expand Up @@ -125,7 +137,7 @@ def argwhere_1d(output_shape, condition):
tag="argwhere1d_gpu",
)

if out.shape[0] <= 1:
if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1:
return out

sorted_out = _get_sort_func()(
Expand Down Expand Up @@ -218,23 +230,56 @@ def argwhere_2d(output_shape, condition):
tag="argwhere2d_gpu",
)

if out.shape[0] <= 1:
if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)) and int(out.shape[0]) <= 1:
return out

sort_func = _get_sort_func(1)

# sort the output from the least significant to the most significant
# column.
out1 = strided_slice(out, [0, 1], [out.shape[0], 2])
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out = adv_index(out, [out3])
if isinstance(out.shape[0], (int, tvm.tir.expr.IntImm)):
out1 = strided_slice(out, [0, 1], [out.shape[0], 2])
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out = adv_index(out, [out3])

out1 = strided_slice(out, [0, 0], [out.shape[0], 1])
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out1 = strided_slice(out, [0, 0], [out.shape[0], 1])
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)

return adv_index(out, [out3])
else:
out_shape = [2]
out_buf = tvm.tir.decl_buffer(out_shape, "int32", "strided_slice_out_buf")
end = te.extern(
[out_shape],
[out],
lambda ins, outs: _create_end(ins[0], outs[0], 2),
dtype="int32",
out_buffers=[out_buf],
name="strided_slice_gpu_end0",
tag="strided_slice_gpu_end0",
)
out1 = strided_slice(out, const_vector([0, 1]), end)
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)
out = adv_index(out, [out3])

out_buf = tvm.tir.decl_buffer(out_shape, "int32", "strided_slice_out_buf")
end = te.extern(
[out_shape],
[out],
lambda ins, outs: _create_end(ins[0], outs[0], 1),
dtype="int32",
out_buffers=[out_buf],
name="strided_slice_gpu_end1",
tag="strided_slice_gpu_end1",
)
out1 = strided_slice(out, const_vector([0, 0]), end)
out2 = sort_func(out1, axis=0, dtype="int32")
out3 = squeeze(out2)

return adv_index(out, [out3])
return adv_index(out, [out3])


def argwhere_3d_ir(condition, out):
Expand Down
2 changes: 1 addition & 1 deletion tests/python/topi/python/test_topi_argwhere.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def check_device(device, ctx):

func = tvm.build(sch, [out_shape, condition, out], device, name="argwhere")

# print(func.imported_modules[0].get_source())
print(func.imported_modules[0].get_source())

args = [tvm.nd.array(np_shape, ctx)]
args.append(tvm.nd.array(np_data, ctx))
Expand Down

0 comments on commit e7dc317

Please sign in to comment.