Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][Arith] Add more strict checking in imm construction and folding. #12515

Merged
merged 9 commits into from
Sep 9, 2022
9 changes: 8 additions & 1 deletion include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,9 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
if (t.is_uint()) {
// Use IntImm if it is a small integer
uint64_t uval = static_cast<uint64_t>(value);
if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
if (value < static_cast<ValueType>(0)) {
LOG(FATAL) << "cannot make uint from negative value " << value;
} else if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
return IntImm(t, static_cast<int64_t>(value), span);
} else {
uint64_t mask = (static_cast<uint64_t>(1) << 32U) - 1U;
Expand All @@ -932,6 +934,11 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
return PrimExpr();
}

// Specialization that routes bool constants through the integer overload,
// so a bool literal becomes a 0/1 immediate of the requested dtype.
template <>
inline PrimExpr MakeConstScalar(DataType t, bool value, Span span) {
  const int as_int = value ? 1 : 0;
  return MakeConstScalar(t, as_int, span);
}

template <typename ValueType, typename>
inline PrimExpr make_const(DataType t, ValueType value, Span span) {
if (t.lanes() == 1) {
Expand Down
14 changes: 10 additions & 4 deletions python/tvm/runtime/object_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,17 @@ def _scalar_type_inference(value):
elif isinstance(value, bool):
dtype = "bool"
elif isinstance(value, float):
# We intentionally convert the float to float32 since it's more common in DL.
dtype = "float32"
# We intentionally prefer convert the float to float32 since it's more common in DL.
if -3.40282347e38 <= value <= 3.40282347e38:
dtype = "float32"
else:
dtype = "float64"
elif isinstance(value, int):
# We intentionally convert the python int to int32 since it's more common in DL.
dtype = "int32"
# We intentionally prefer convert the python int to int32 since it's more common in DL.
if -2147483648 <= value <= 2147483647:
dtype = "int32"
else:
dtype = "int64"
else:
raise NotImplementedError(
"Cannot automatically inference the type." " value={}".format(value)
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/script/tir/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ def truncmod(x, y, span):
return tvm.tir.truncmod(x, y, span)


@register
def truncdiv(x, y, span):
    """TVM script intrinsic forwarding to tvm.tir.truncdiv."""
    result = tvm.tir.truncdiv(x, y, span)
    return result


@register
def ceildiv(x, y, span):
    """TVM script intrinsic forwarding to tvm.tir.ceildiv."""
    result = tvm.tir.ceildiv(x, y, span)
    return result
Expand Down
112 changes: 100 additions & 12 deletions src/arith/const_fold.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include <algorithm>
#include <cmath>
#include <limits>

#include "int_operator.h"

Expand Down Expand Up @@ -73,6 +74,39 @@ inline bool IsIndexType(const DataType& type) {
return type.is_int() && type.lanes() == 1 && (type.bits() == 32 || type.bits() == 64);
}

/*! \brief Helper to get const folding result repr in int64.
 *
 * Truncates \p x to the bit-width of \p dtype and, for signed dtypes,
 * sign-extends the truncated value back into a canonical int64 value.
 * Types of 64 bits or more are returned unchanged; this early return also
 * avoids the signed-overflow UB the original `(x ^ m) - m` trick would hit
 * when m == 2^63 (e.g. for any negative int64 input).
 */
inline int64_t GetFoldResultInt64Repr(int64_t x, const DataType& dtype) {
  if (dtype.bits() >= 64) {
    // Full-width type: nothing to truncate or sign-extend.
    return x;
  }
  // Keep only the low dtype.bits() bits; do the mask in unsigned
  // arithmetic so the shift is well-defined for all widths < 64.
  uint64_t mask = (1ULL << dtype.bits()) - 1;
  x = static_cast<int64_t>(static_cast<uint64_t>(x) & mask);
  if (dtype.is_int()) {
    // Sign-extend the truncated value: flip the sign bit, then subtract it.
    // After masking, x is in [0, 2^bits - 1], so this cannot overflow.
    int64_t m = 1LL << (dtype.bits() - 1);
    x = (x ^ m) - m;
  }
  return x;
}

/*! \brief Helper to get fp32 const folding result repr in double. */
inline double GetFoldResultDoubleRepr(float x) {
  const double widened = static_cast<double>(x);
  // Non-finite results (inf/nan) pass through unchanged.
  if (std::isinf(widened) || std::isnan(widened)) {
    return widened;
  }
  // certain platform (eg, on gcc7-i386) do the folding arithmetic
  // on float and write back to double is optimized to double
  // precision arithmetic, this is legal and we check the output
  // range thus to ensure consistency when the float result is inf.
  const double lo = static_cast<double>(std::numeric_limits<float>::lowest());
  const double hi = static_cast<double>(std::numeric_limits<float>::max());
  if (widened < lo) {
    LOG(WARNING) << "underlying float value overflow";
    return -std::numeric_limits<double>::infinity();
  }
  if (widened > hi) {
    LOG(WARNING) << "underlying float value overflow";
    return std::numeric_limits<double>::infinity();
  }
  return widened;
}

#define TVM_ARITH_CONST_PROPAGATION(BODY) \
using tir::FloatImmNode; \
const IntImmNode* pa = a.as<IntImmNode>(); \
Expand All @@ -95,10 +129,22 @@ template <>
inline PrimExpr TryConstFold<tir::Add>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, pa->value + pb->value);
if (pa && pb) {
int64_t res = pa->value + pb->value;
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pa && pa->value == 0) return b;
if (pb && pb->value == 0) return a;
if (fa && fb) return FloatImm(rtype, fa->value + fb->value);
if (fa && fb) {
if (rtype.bits() == 32) {
return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) +
static_cast<float>(fb->value)));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value + fb->value);
} else {
return PrimExpr();
}
}
if (fa && fa->value == 0) return b;
if (fb && fb->value == 0) return a;
});
Expand All @@ -113,9 +159,21 @@ inline PrimExpr TryConstFold<tir::Sub>(PrimExpr a, PrimExpr b) {
<< "Checked failed. Minuend 's value is 0U and it's dtype is uint "
<< "while Subtrahend's dtype is uint; which will cause a negative uint";
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, pa->value - pb->value);
if (pa && pb) {
int64_t res = pa->value - pb->value;
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pb && pb->value == 0) return a;
if (fa && fb) return FloatImm(rtype, fa->value - fb->value);
if (fa && fb) {
if (rtype.bits() == 32) {
return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) -
static_cast<float>(fb->value)));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value - fb->value);
} else {
return PrimExpr();
}
}
if (fb && fb->value == 0) return a;
});
return PrimExpr();
Expand All @@ -125,7 +183,10 @@ template <>
inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, pa->value * pb->value);
if (pa && pb) {
int64_t res = pa->value * pb->value;
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pa) {
if (pa->value == 1) return b;
if (pa->value == 0) return a;
Expand All @@ -134,7 +195,16 @@ inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
if (pb->value == 1) return a;
if (pb->value == 0) return b;
}
if (fa && fb) return FloatImm(rtype, fa->value * fb->value);
if (fa && fb) {
if (rtype.bits() == 32) {
return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) *
static_cast<float>(fb->value)));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value * fb->value);
} else {
return PrimExpr();
}
}
if (fa) {
if (fa->value == 1) return b;
if (fa->value == 0) return a;
Expand All @@ -155,7 +225,8 @@ inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
// due to division and mod can have different modes
// NOTE: this will assumes truc div.
ICHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, pa->value / pb->value);
int64_t res = pa->value / pb->value;
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pa) {
if (pa->value == 0) return a;
Expand All @@ -165,7 +236,14 @@ inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
}
if (fa && fb && fb->value != 0) {
return FloatImm(rtype, fa->value / fb->value);
if (rtype.bits() == 32) {
return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast<float>(fa->value) /
static_cast<float>(fb->value)));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value / fb->value);
} else {
return PrimExpr();
}
}
if (fa && fa->value == 0) return a;
if (fb) {
Expand All @@ -182,7 +260,8 @@ inline PrimExpr TryConstFold<tir::Mod>(PrimExpr a, PrimExpr b) {
const DataType& rtype = a.dtype();
if (pa && pb) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, pa->value % pb->value);
int64_t res = pa->value % pb->value;
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pa) {
if (pa->value == 0) return a;
Expand All @@ -201,7 +280,8 @@ inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
const DataType& rtype = a.dtype();
if (pa && pb) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, arith::floordiv(pa->value, pb->value));
int64_t res = arith::floordiv(pa->value, pb->value);
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pa) {
if (pa->value == 0) return a;
Expand All @@ -211,7 +291,14 @@ inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
}
if (fa && fb && fb->value != 0) {
return FloatImm(rtype, std::floor(fa->value / fb->value));
if (rtype.bits() == 32) {
return FloatImm(rtype, GetFoldResultDoubleRepr(std::floor(static_cast<float>(fa->value) /
static_cast<float>(fb->value))));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, std::floor(fa->value / fb->value));
} else {
return PrimExpr();
}
}
if (fa && fa->value == 0) return a;
if (fb) {
Expand All @@ -228,7 +315,8 @@ inline PrimExpr TryConstFold<tir::FloorMod>(PrimExpr a, PrimExpr b) {
const DataType& rtype = a.dtype();
if (pa && pb) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, floormod(pa->value, pb->value));
int64_t res = arith::floormod(pa->value, pb->value);
return IntImm(rtype, GetFoldResultInt64Repr(res, rtype));
}
if (pa) {
if (pa->value == 0) return a;
Expand Down
32 changes: 31 additions & 1 deletion src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
#include <tvm/te/tensor.h>
#include <tvm/tir/expr.h>

#include "../support/scalars.h"

namespace tvm {

PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) {}
Expand Down Expand Up @@ -76,7 +78,20 @@ IntImm::IntImm(DataType dtype, int64_t value, Span span) {
ICHECK(dtype.is_int() || dtype.is_uint())
<< "ValueError: IntImm supports only int or uint type, but " << dtype << " was supplied.";
if (dtype.is_uint()) {
ICHECK_GE(value, 0U);
ICHECK_GE(value, 0U) << "ValueError: Literal value " << value
<< " is negative for unsigned integer type " << dtype;
if (dtype.bits() < 64) {
ICHECK_LT(value, 1LL << dtype.bits())
<< "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
}
} else if (dtype.bits() == 1) {
// int(1)
ICHECK(value == 0 || value == 1) << "ValueError: " << value << " exceeds range of " << dtype;
} else if (dtype.bits() < 64) {
ICHECK_GE(value, -(1LL << (dtype.bits() - 1)))
<< "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
ICHECK_LT(value, 1LL << (dtype.bits() - 1))
<< "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
}
ObjectPtr<IntImmNode> node = make_object<IntImmNode>();
node->dtype = dtype;
Expand All @@ -103,6 +118,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

FloatImm::FloatImm(DataType dtype, double value, Span span) {
ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar.";

// check range for float32 and float16 since they have specified range.
if (!std::isinf(value) && !std::isnan(value)) {
if (dtype.bits() == 32) {
ICHECK_GE(value, std::numeric_limits<float>::lowest())
<< "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
ICHECK_LE(value, std::numeric_limits<float>::max())
<< "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
} else if (dtype.is_float16()) {
ICHECK_GE(value, -support::kMaxFloat16)
<< "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
ICHECK_LE(value, support::kMaxFloat16)
<< "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
}
}
ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
node->dtype = dtype;
node->value = value;
Expand Down
4 changes: 0 additions & 4 deletions src/support/scalars.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,6 @@ IntImm ValueToIntImm(int64_t value, int width) {
}
}

// 2^15 * (1 + 1023/1024)
// See https://en.wikipedia.org/wiki/Half-precision_floating-point_format
constexpr double kMaxFloat16 = 65504.0;

FloatImm ValueToFloatImm(double value, int width) {
if (width == 16) {
if (!std::isinf(value) && (value < -kMaxFloat16 || value > kMaxFloat16)) {
Expand Down
4 changes: 4 additions & 0 deletions src/support/scalars.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ std::string FloatImmToString(const FloatImm& float_imm);
IntImm ValueToIntImm(int64_t value, int width);
FloatImm ValueToFloatImm(double value, int width);

// 2^15 * (1 + 1023/1024)
// See https://en.wikipedia.org/wiki/Half-precision_floating-point_format
constexpr double kMaxFloat16 = 65504.0;

} // namespace support
} // namespace tvm

Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_op_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def verify(
# Test backwards slicing.
verify((3, 4, 3), [-1, -1, -1], [-5, -5, -5], [-1, -1, -1], (3, 4, 3))
# Test slicing with overlarge indices.
verify((3, 4, 3), [0, 0, 0], [np.iinfo(np.int64).max] * 3, [1, 1, 1], (3, 4, 3))
verify((3, 4, 3), [0, 0, 0], [np.iinfo(np.int32).max] * 3, [1, 1, 1], (3, 4, 3))
# Test slice mode.
verify(
(3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_pass_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ def test_fuse_dynamic_squeeze_slice_take():

squeeze = relay.op.squeeze(x, axis=[0])
strided_slice = relay.op.strided_slice(
squeeze, begin=[0, 0], end=[15130, 9223372036854775807], strides=[1, 1]
squeeze, begin=[0, 0], end=[15130, 2147483647], strides=[1, 1]
)
take = relay.op.take(strided_slice, take_val, axis=0)

Expand Down
2 changes: 2 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,8 @@ def test_cast_simplify():
ck.verify(tvm.tir.Cast(dtype1, x == x), tvm.tir.const(1, dtype1))
for dtype2 in dtypes:
for i in [0, 1, 2, 3]:
if i > 1 and (dtype1 == "bool" or dtype2 == "bool"):
continue
ck.verify(tvm.tir.Cast(dtype1, tvm.tir.const(i, dtype2)), tvm.tir.const(i, dtype1))


Expand Down
7 changes: 4 additions & 3 deletions tests/python/unittest/test_target_codegen_cuda.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Licensed to the Apache Software Foundation (ASF) under one

# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
Expand Down Expand Up @@ -194,13 +195,13 @@ def check_cuda(n, value, lanes):
fun(a)
np.testing.assert_equal(a.numpy(), np_a)

check_cuda(64, 0xAB, 4)
check_cuda(64, np.int8(0xAB), 4)
check_cuda(64, 0, 4)
check_cuda(64, -3, 4)
check_cuda(64, 0xAB, 3)
check_cuda(64, np.int8(0xAB), 3)
check_cuda(64, 0, 3)
check_cuda(64, -3, 3)
check_cuda(64, 0xAB, 2)
check_cuda(64, np.int8(0xAB), 2)
check_cuda(64, 0, 2)
check_cuda(64, -3, 2)

Expand Down
Loading