Skip to content

Commit

Permalink
[CherryPick][Intrinsic] lower_bound and upper_bound for binary search…
Browse files Browse the repository at this point in the history
… in Sparse TIR. (apache#483) (#4)

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* codegen-rule

* upd

* upd

* test

* upd

* fix

* two arguments

Co-authored-by: Zihao Ye <[email protected]>
  • Loading branch information
MasterJH5574 and yzh119 committed Nov 5, 2021
1 parent 7d90bc5 commit a701744
Show file tree
Hide file tree
Showing 10 changed files with 247 additions and 1 deletion.
10 changes: 10 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,16 @@ TVM_DLL const Op& tvm_warp_shuffle_up();
TVM_DLL const Op& tvm_warp_shuffle_down();
TVM_DLL const Op& tvm_warp_activemask();

/*!
* \brief Lower bound function for binary search.
*/
TVM_DLL const Op& tvm_lower_bound();

/*!
* \brief Upper bound function for binary search.
*/
TVM_DLL const Op& tvm_upper_bound();

/*!
* \brief Initialize the global barrier.
* Call this at beginning of kernel that need global barrier.
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/script/tir/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,16 @@ def max_value(dtype, span):
return tvm.tir.max_value(dtype, span)


@register
def lower_bound(arr, val, l, r, span):
return tvm.tir.lower_bound(arr, val, l, r, span)


@register
def upper_bound(arr, val, l, r, span):
return tvm.tir.upper_bound(arr, val, l, r, span)


@register
def floordiv(x, y, span):
return tvm.tir.floordiv(x, y, span)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from .function import PrimFunc

from .op import call_packed, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace, lower_bound, upper_bound
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
from .op import sin, sinh, asin, asinh
from .op import cos, cosh, acos, acosh
Expand Down
56 changes: 56 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,62 @@ def ldexp(x1, x2):
return call_intrin(x1.dtype, "tir.ldexp", x1, x2) # type: ignore


def lower_bound(arr, val, l, r, span=None):
"""Return the position to the first element in the arr[l:r] that is no less than val.
Parameters
----------
arr : Var
Pointer to the 1D buffer to apply binary search on.
val : PrimExpr
Value of the lower bound to search for in the buffer.
l : PrimExpr
Start position to search for in the buffer.
r : PrimExpr
End position to search for in the buffer.
span : Optional[Span]
The location of this expression in the source code.
Returns
-------
PrimExpr
The index of element in arr[l:r] that is no less then given value.
"""
return _ffi_api.lower_bound(arr, val, l, r, span) # type: ignore


def upper_bound(arr, val, l, r, span=None):
"""Return the position the first element in the arr that is greater than val.
Parameters
----------
arr : Var
Pointer to the 1D buffer to apply binary search on.
val : PrimExpr
Value of the upper bound to search for in the buffer.
l : PrimExpr
Start position to search for in the buffer.
r : PrimExpr
End position to search for in the buffer.
span : Optional[Span]
The location of this expression in the source code.
Returns
-------
PrimExpr
The index of element in arr[l:r] that is no less then given value.
"""
return _ffi_api.upper_bound(arr, val, l, r, span) # type: ignore


def isnan(x, span=None):
"""Check if input value is Nan.
Expand Down
29 changes: 29 additions & 0 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <vector>

#include "literal/cuda_half_t.h"
#include "literal/cuda_binary_search.h"

namespace tvm {
namespace codegen {
Expand Down Expand Up @@ -132,6 +133,10 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "#include <mma.h>\n";
}

if (need_binary_search_) {
decl_stream << _cuda_binary_search_def;
}

decl_stream << "\n#ifdef _WIN32\n";
decl_stream << " using uint = unsigned int;\n";
decl_stream << " using uchar = unsigned char;\n";
Expand Down Expand Up @@ -723,6 +728,30 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
this->PrintExpr(op->args[i * 2 + 1], os);
os << "]" << ((i < 3) ? ", " : ")");
}
} else if (op->op.same_as(builtin::tvm_lower_bound())) {
need_binary_search_ = true;
os << "__lower_bound(";
ICHECK_EQ(op->args.size(), 4U);
this->PrintExpr(op->args[0], os);
os << ", ";
this->PrintExpr(op->args[1], os);
os << ", ";
this->PrintExpr(op->args[2], os);
os << ", ";
this->PrintExpr(op->args[3], os);
os << ")";
} else if (op->op.same_as(builtin::tvm_upper_bound())) {
need_binary_search_ = true;
os << "__upper_bound(";
ICHECK_EQ(op->args.size(), 4U);
this->PrintExpr(op->args[0], os);
os << ", ";
this->PrintExpr(op->args[1], os);
os << ", ";
this->PrintExpr(op->args[2], os);
os << ", ";
this->PrintExpr(op->args[3], os);
os << ")";
} else {
CodeGenC::VisitExpr_(op, os);
}
Expand Down
2 changes: 2 additions & 0 deletions src/target/source/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class CodeGenCUDA final : public CodeGenC {
bool need_math_constants_h_{false};
// whether need mma.h
bool need_mma_h_{false};
// whether need binary search
bool need_binary_search_{false};
// Op attribute map
OpAttrMap<bool> op_need_warp_shuffle_ = Op::GetAttrMap<bool>("cuda.need_warp_shuffle");

Expand Down
69 changes: 69 additions & 0 deletions src/target/source/literal/cuda_binary_search.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file cuda_binary_search.h
* \brief Binary search function definition for cuda codegen.
*/
#ifndef TVM_TARGET_SOURCE_LITERAL_CUDA_BINARY_SEARCH_H_
#define TVM_TARGET_SOURCE_LITERAL_CUDA_BINARY_SEARCH_H_

static constexpr const char* _cuda_binary_search_def = R"(
template <typename DType>
__forceinline__ __device__ int32_t __lower_bound(
const DType* __restrict__ arr,
DType val,
int32_t l,
int32_t r) {
int32_t low = l - 1, high = r;
/* loop invariant: low < mid < high, arr[low] < val, arr[high] >= val */
while (low + 1 < high) {
int32_t mid = (low + high) >> 1;
if (arr[mid] < val) {
low = mid;
} else {
high = mid;
}
}
// high = low + 1, arr[low] < val, arr[high] >= val
return high;
}
template <typename DType>
__forceinline__ __device__ int32_t __upper_bound(
const DType* __restrict__ arr,
DType val,
int32_t l,
int32_t r) {
int32_t low = l - 1, high = r;
/* loop invariant: low < mid < high, arr[low] < val, arr[high] > val */
while (low + 1 < high) {
int32_t mid = (low + high) >> 1;
if (arr[mid] > val) {
high = mid;
} else {
low = mid;
}
}
// high = low + 1, arr[low] <= val, arr[high] > val
return high;
}
)";

#endif // TVM_TARGET_SOURCE_LITERAL_CUDA_BINARY_SEARCH_H_
6 changes: 6 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,12 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_thread_allreduce)
TIR_DEFINE_BUILTIN_FUNC(tvm_load_matrix_sync)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kReadState));

TIR_DEFINE_BUILTIN_FUNC(tvm_lower_bound)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(tvm_upper_bound)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(tvm_mma_sync)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

Expand Down
14 changes: 14 additions & 0 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,16 @@ PrimExpr nearbyint(PrimExpr x, Span span) {

TIR_REGISTER_PURE_UNARY_OP("tir.nearbyint");

// lower_bound
PrimExpr lower_bound(Var arr, PrimExpr val, PrimExpr l, PrimExpr r, Span span) {
return tir::Call({kDLInt, 32, 1}, builtin::tvm_lower_bound(), {arr, val, l, r}, span);
}

// upper_bound
PrimExpr upper_bound(Var arr, PrimExpr val, PrimExpr l, PrimExpr r, Span span) {
return tir::Call({kDLInt, 32, 1}, builtin::tvm_upper_bound(), {arr, val, l, r}, span);
}

// trunc
PrimExpr trunc(PrimExpr x, Span span) {
if (x.dtype().is_int() || x.dtype().is_uint()) {
Expand Down Expand Up @@ -918,6 +928,10 @@ TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc);

TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast);

TVM_REGISTER_GLOBAL("tir.lower_bound").set_body_typed(tvm::lower_bound);

TVM_REGISTER_GLOBAL("tir.upper_bound").set_body_typed(tvm::upper_bound);

// operator overloading, smarter than make
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \
Expand Down
50 changes: 50 additions & 0 deletions tests/python/unittest/test_tir_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,55 @@ def test_fma():
assert mod["test_tir_fma"].body.body.value.op.name == "tir.call_llvm_pure_intrin"


@tvm.script.tir
def binary_search(a: ty.handle, b: ty.handle, c: ty.handle, d: ty.handle) -> None:
n = tir.var('int32')
m = tir.var('int32')
A = tir.match_buffer(a, (n,), dtype='int32')
B = tir.match_buffer(b, (m,), dtype='int32')
C = tir.match_buffer(c, (m,), dtype='int32')
D = tir.match_buffer(d, (m,), dtype='int32')
with tir.block([m], 'search') as [vi]:
tir.reads([A[0:n], B[vi]])
tir.writes([C[vi], D[vi]])
C[vi] = tir.lower_bound(A.data, B[vi], 0, n)
D[vi] = tir.upper_bound(A.data, B[vi], 0, n)


def test_binary_search():
sch = tir.Schedule(binary_search)
b = sch.get_block('search')
i, = sch.get_loops(b)
io, ii = sch.split(i, [1, None])
sch.bind(io, 'threadIdx.x')
sch.bind(ii, 'blockIdx.x')
f = tvm.build(sch.mod['main'], target='cuda')
# print(f.imported_modules[0].get_source())

x = np.arange(-128, 128).astype(np.int32)
y = np.random.randint(-200, 200, size=1024).astype(np.int32)
a = np.zeros((1024,)).astype(np.int32)
b = np.zeros((1024,)).astype(np.int32)

# numpy results
np_a = np.searchsorted(x, y, side='left').astype(np.int32)
np_b = np.searchsorted(x, y, side='right').astype(np.int32)

# tvm results
dev = tvm.cuda(0)
x_array = tvm.nd.array(x, device=dev)
y_array = tvm.nd.array(y, device=dev)
a_array = tvm.nd.array(a, device=dev)
b_array = tvm.nd.array(b, device=dev)
f(x_array, y_array, a_array, b_array)
tvm_a = a_array.numpy()
tvm_b = b_array.numpy()

# verify result
tvm.testing.assert_allclose(np_a, tvm_a)
tvm.testing.assert_allclose(np_b, tvm_b)


if __name__ == "__main__":
test_nearbyint()
test_unary_intrin()
Expand All @@ -261,3 +310,4 @@ def test_fma():
test_ldexp()
test_clz()
test_fma()
test_binary_search()

0 comments on commit a701744

Please sign in to comment.