From be3975ab5833f205af1f0a542c92aedd391673e6 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 30 Oct 2021 06:02:16 +0800 Subject: [PATCH] [BugFix] Fix binary search & SpIterVar (#7) --- include/tvm/tir/op.h | 24 ++++++++++++++++++++++++ src/tir/op/op.cc | 4 ++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 9cf7d0a3cd1f..e800bb51748c 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -793,6 +793,30 @@ TVM_DLL PrimExpr round(PrimExpr x, Span span = Span()); */ TVM_DLL PrimExpr nearbyint(PrimExpr x, Span span = Span()); +/*! + * \brief Lower bound function for binary search + * \param arr The buffer variable of the array to be looked up in + * \param val The value to be looked up in the array + * \param l The left boundary of the look-up range (inclusive) + * \param r The right boundary of the look-up range (exclusive) + * \param span The location of this operation in the source + * \return The look-up result + */ +TVM_DLL PrimExpr lower_bound(tir::Var arr, PrimExpr val, PrimExpr l, PrimExpr r, + Span span = Span()); + +/*! + * \brief Upper bound function for binary search + * \param arr The buffer variable of the array to be looked up in + * \param val The value to be looked up in the array + * \param l The left boundary of the look-up range (inclusive) + * \param r The right boundary of the look-up range (exclusive) + * \param span The location of this operation in the source + * \return The look-up result + */ +TVM_DLL PrimExpr upper_bound(tir::Var arr, PrimExpr val, PrimExpr l, PrimExpr r, + Span span = Span()); + /*! * \brief Calculate trunc(x) * \param x The input expression. diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 0ce984f5eeec..7f3e3ae5df97 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -805,12 +805,12 @@ PrimExpr nearbyint(PrimExpr x, Span span) { TIR_REGISTER_PURE_UNARY_OP("tir.nearbyint"); // lower_bound -PrimExpr lower_bound(Var arr, PrimExpr val, PrimExpr l, PrimExpr r, Span span) { +PrimExpr lower_bound(tir::Var arr, PrimExpr val, PrimExpr l, PrimExpr r, Span span) { return tir::Call({kDLInt, 32, 1}, builtin::tvm_lower_bound(), {arr, val, l, r}, span); } // upper_bound -PrimExpr upper_bound(Var arr, PrimExpr val, PrimExpr l, PrimExpr r, Span span) { +PrimExpr upper_bound(tir::Var arr, PrimExpr val, PrimExpr l, PrimExpr r, Span span) { return tir::Call({kDLInt, 32, 1}, builtin::tvm_upper_bound(), {arr, val, l, r}, span); }