Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][Arith] Add more strict checking in imm construction and folding. #12515

Merged
merged 9 commits into from
Sep 9, 2022
4 changes: 3 additions & 1 deletion include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,9 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
if (t.is_uint()) {
// Use IntImm if it is a small integer
uint64_t uval = static_cast<uint64_t>(value);
if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
if (static_cast<int64_t>(value) < 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe here we should do value < static_cast<ValueType>(0). Because if value is of uint64_t, it will be regarded as a negative number if it is greater than numeric_limits<int64_t>::max().

LOG(FATAL) << "cannot make uint from negative value " << value;
} else if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
return IntImm(t, static_cast<int64_t>(value), span);
} else {
uint64_t mask = (static_cast<uint64_t>(1) << 32U) - 1U;
Expand Down
87 changes: 75 additions & 12 deletions src/arith/const_fold.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,19 @@ inline bool IsIndexType(const DataType& type) {
return type.is_int() && type.lanes() == 1 && (type.bits() == 32 || type.bits() == 64);
}

/*! \brief Helper to get const folding result repr in int64. */
inline int64_t GetInt64FoldResultRepr(int64_t x, const DataType& dtype) {
if (dtype.bits() < 64) {
x &= (1LL << dtype.bits()) - 1;
}
if (dtype.is_int()) {
// get sign extended value of integer with specified bits
int64_t m = 1LL << (dtype.bits() - 1);
x = (x ^ m) - m;
}
return x;
}

#define TVM_ARITH_CONST_PROPAGATION(BODY) \
using tir::FloatImmNode; \
const IntImmNode* pa = a.as<IntImmNode>(); \
Expand All @@ -95,10 +108,21 @@ template <>
inline PrimExpr TryConstFold<tir::Add>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, pa->value + pb->value);
if (pa && pb) {
int64_t res = pa->value + pb->value;
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
}
if (pa && pa->value == 0) return b;
if (pb && pb->value == 0) return a;
if (fa && fb) return FloatImm(rtype, fa->value + fb->value);
if (fa && fb) {
if (rtype.bits() == 32) {
return FloatImm(rtype, static_cast<float>(fa->value) + static_cast<float>(fb->value));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value + fb->value);
} else {
return PrimExpr();
}
}
if (fa && fa->value == 0) return b;
if (fb && fb->value == 0) return a;
});
Expand All @@ -113,9 +137,20 @@ inline PrimExpr TryConstFold<tir::Sub>(PrimExpr a, PrimExpr b) {
<< "Checked failed. Minuend 's value is 0U and it's dtype is uint "
<< "while Subtrahend's dtype is uint; which will cause a negative uint";
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, pa->value - pb->value);
if (pa && pb) {
int64_t res = pa->value - pb->value;
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
}
if (pb && pb->value == 0) return a;
if (fa && fb) return FloatImm(rtype, fa->value - fb->value);
if (fa && fb) {
if (rtype.bits() == 32) {
return FloatImm(rtype, static_cast<float>(fa->value) - static_cast<float>(fb->value));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value - fb->value);
} else {
return PrimExpr();
}
}
if (fb && fb->value == 0) return a;
});
return PrimExpr();
Expand All @@ -125,7 +160,10 @@ template <>
inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, pa->value * pb->value);
if (pa && pb) {
int64_t res = pa->value * pb->value;
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
}
if (pa) {
if (pa->value == 1) return b;
if (pa->value == 0) return a;
Expand All @@ -134,7 +172,15 @@ inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
if (pb->value == 1) return a;
if (pb->value == 0) return b;
}
if (fa && fb) return FloatImm(rtype, fa->value * fb->value);
if (fa && fb) {
if (rtype.bits() == 32) {
return FloatImm(rtype, static_cast<float>(fa->value) * static_cast<float>(fb->value));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value * fb->value);
} else {
return PrimExpr();
}
}
if (fa) {
if (fa->value == 1) return b;
if (fa->value == 0) return a;
Expand All @@ -155,7 +201,8 @@ inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
// due to division and mod can have different modes
// NOTE: this will assumes truc div.
ICHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, pa->value / pb->value);
int64_t res = pa->value / pb->value;
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
}
if (pa) {
if (pa->value == 0) return a;
Expand All @@ -165,7 +212,13 @@ inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
}
if (fa && fb && fb->value != 0) {
return FloatImm(rtype, fa->value / fb->value);
if (rtype.bits() == 32) {
return FloatImm(rtype, static_cast<float>(fa->value) / static_cast<float>(fb->value));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value / fb->value);
} else {
return PrimExpr();
}
}
if (fa && fa->value == 0) return a;
if (fb) {
Expand All @@ -182,7 +235,8 @@ inline PrimExpr TryConstFold<tir::Mod>(PrimExpr a, PrimExpr b) {
const DataType& rtype = a.dtype();
if (pa && pb) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, pa->value % pb->value);
int64_t res = pa->value % pb->value;
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
}
if (pa) {
if (pa->value == 0) return a;
Expand All @@ -201,7 +255,8 @@ inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
const DataType& rtype = a.dtype();
if (pa && pb) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, arith::floordiv(pa->value, pb->value));
int64_t res = arith::floordiv(pa->value, pb->value);
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
}
if (pa) {
if (pa->value == 0) return a;
Expand All @@ -211,7 +266,14 @@ inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
}
if (fa && fb && fb->value != 0) {
return FloatImm(rtype, std::floor(fa->value / fb->value));
if (rtype.bits() == 32) {
return FloatImm(rtype,
std::floor(static_cast<float>(fa->value) / static_cast<float>(fb->value)));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, std::floor(fa->value / fb->value));
} else {
return PrimExpr();
}
}
if (fa && fa->value == 0) return a;
if (fb) {
Expand All @@ -228,7 +290,8 @@ inline PrimExpr TryConstFold<tir::FloorMod>(PrimExpr a, PrimExpr b) {
const DataType& rtype = a.dtype();
if (pa && pb) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, floormod(pa->value, pb->value));
int64_t res = arith::floormod(pa->value, pb->value);
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
}
if (pa) {
if (pa->value == 0) return a;
Expand Down
15 changes: 14 additions & 1 deletion src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,20 @@ IntImm::IntImm(DataType dtype, int64_t value, Span span) {
ICHECK(dtype.is_int() || dtype.is_uint())
<< "ValueError: IntImm supports only int or uint type, but " << dtype << " was supplied.";
if (dtype.is_uint()) {
ICHECK_GE(value, 0U);
ICHECK_GE(value, 0U) << "ValueError: Literal value " << value
<< " is negative for unsigned integer type " << dtype;
if (dtype.bits() < 64) {
ICHECK_LT(value, 1LL << dtype.bits())
<< "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
}
} else if (dtype.bits() == 1) {
// int(1)
ICHECK(value == 0 || value == 1) << "ValueError: " << value << " exceeds range of " << dtype;
} else if (dtype.bits() < 64) {
ICHECK_GE(value, -(1LL << (dtype.bits() - 1)))
<< "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
ICHECK_LT(value, 1LL << (dtype.bits() - 1))
<< "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
}
ObjectPtr<IntImmNode> node = make_object<IntImmNode>();
node->dtype = dtype;
Expand Down
2 changes: 2 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,8 @@ def test_cast_simplify():
ck.verify(tvm.tir.Cast(dtype1, x == x), tvm.tir.const(1, dtype1))
for dtype2 in dtypes:
for i in [0, 1, 2, 3]:
if i > 1 and (dtype1 == "bool" or dtype2 == "bool"):
continue
ck.verify(tvm.tir.Cast(dtype1, tvm.tir.const(i, dtype2)), tvm.tir.const(i, dtype1))


Expand Down
155 changes: 155 additions & 0 deletions tests/python/unittest/test_tir_imm_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import math
import numpy as np
import tvm
import tvm.testing
import pytest
from tvm import tir
from tvm.script import tir as T
import pytest


@pytest.mark.parametrize(
"dtype, literals",
[
["int8", [-128, 0, 127]],
["uint8", [0, 255]],
["int32", [-2147483648, 2147483647]],
["uint32", [0, 4294967295]],
["int64", [-9223372036854775808, 9223372036854775807]],
["uint64", [0, 9223372036854775807]],
],
)
def test_tir_make_intimm(dtype, literals):
for l in literals:
imm = tir.const(l, dtype)
assert imm.value == l, imm


@pytest.mark.parametrize(
"dtype, literals",
[
["int8", [-129, 128]],
["uint8", [-1, 256]],
["int32", [-2147483650, 2147483648]],
["uint32", [-1, 4294967296]],
["uint64", [-1, 18446744073709551616]],
],
)
def test_tir_invalid_intimm(dtype, literals):
for l in literals:
with pytest.raises(tvm.TVMError):
tir.const(l, dtype)


@pytest.mark.parametrize(
"dtype, literals",
[
[
"int64",
{
-9223372036854775810: 9223372036854775806,
9223372036854775808: -9223372036854775808,
},
],
[
"uint64",
{
9223372036854775807: 9223372036854775807,
18446744073709551615: 18446744073709551615,
},
],
],
)
def test_tir_large_py_int_literals(dtype, literals):
"""
For large uint value, use LargeUIntImm intrin,
For large int value exceed int64_t value range, the value is wrapped back.
"""
for l in literals:
x = tir.const(l, dtype)
if isinstance(x, (tir.IntImm, tir.FloatImm)):
assert x.value == literals[l]
else:
# LargeUIntImm(low32, hi32)
assert (int(x.args[1]) << 32) + int(x.args[0]) == literals[l]


def test_tir_intimm_overflow():
assert int(tir.const(127, "int8") + tir.const(1, "int8")) == -128
assert int(tir.const(127, "int8") + tir.const(2, "int8")) == -127
assert int(tir.const(255, "uint8") + tir.const(1, "uint8")) == 0
assert int(tir.const(2**31 - 1, "int32") + tir.const(1, "int32")) == -(2**31)
assert int(tir.const(2**32 - 1, "uint32") + tir.const(1, "uint32")) == 0
assert int(tir.const(2**63 - 1, "int64") + tir.const(1, "int64")) == -(2**63)
assert int(tir.const(2**32, "uint64") * tir.const(2**32, "uint64")) == 0


def compare_float_value(value, expect):
if math.isfinite(value):
assert value == expect
elif math.isnan(value):
assert math.isnan(expect)
elif math.isinf(value):
assert math.isinf(expect)


@pytest.mark.parametrize("dtype", ["float16", "float32", "float64"])
@pytest.mark.parametrize("literal", [3.14, np.nan, np.inf])
def test_tir_special_floatimms(dtype, literal):
x = tir.const(literal, dtype)
compare_float_value(x.value, literal)


@tvm.testing.requires_llvm()
def test_tir_floatimm_overflow():
# Behavior check: if literal value is out of dtype range, the
# object is still constructed, and eval to infinity.
@T.prim_func
def imm_overflow_fp16() -> T.float16:
T.evaluate(T.ret(T.float16(65536), dtype="float16"))
wrongtest-intellif marked this conversation as resolved.
Show resolved Hide resolved

f = tvm.build(imm_overflow_fp16, target="llvm")
assert math.isinf(f())

@T.prim_func
def imm_overflow_fp32() -> T.float32:
T.evaluate(T.ret(T.float32(3.4028e39), dtype="float32"))

f = tvm.build(imm_overflow_fp32, target="llvm")
assert math.isinf(f())

@T.prim_func
def imm_overflow_fp64() -> T.float64:
T.evaluate(T.ret(T.float64(1.7976e309), dtype="float64"))

f = tvm.build(imm_overflow_fp64, target="llvm")
assert math.isinf(f())

# Behavior check: disable fp16 folding
assert float(tir.const(1.0, "float32") * tir.const(2.0, "float32")) == 2.0
assert not isinstance(tir.const(1.0, "float16") * tir.const(2.0, "float16"), tir.FloatImm)

# Behavior check: folding when fp32 overflow get infinity
x = np.float32(3.4028235e37)
y = np.float32(3.4028235e37)
assert math.isinf(float(tir.const(x, "float32") * tir.const(y, "float32")))


if __name__ == "__main__":
tvm.testing.main()
9 changes: 0 additions & 9 deletions tests/python/unittest/test_tir_transform_narrow_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ def check(m, n, target_bits, target_dtype):
# const shape
# i32 -> i32
check(2, 2, 32, "int32")
# i32 + i32 is not promoted to i64 even if overflow
check(2**16, 2**16, 32, "int32")
# i64 -> i32
check(const(2, dtype="int64"), const(2, dtype="int64"), 32, "int32")
check(const(2**16, dtype="int64"), const(2**16, dtype="int64"), 32, "int64")
Expand Down Expand Up @@ -100,12 +98,6 @@ def check(m, n, target_bits, target_dtype):

# i32 -> i32
check(2, 32, target_bits=32, target_dtype="int32")
check(
2**30,
32, # i32 + i32 is not promoted to i64 even in the case of overflow
target_bits=32,
target_dtype="int32",
)
# i64 -> i32
check(const(2, dtype="int64"), const(32, dtype="int64"), target_bits=32, target_dtype="int32")
check(
Expand Down Expand Up @@ -162,7 +154,6 @@ def check(m, lanes, target_bits, target_dtype):

# i32 -> i32
check(const(2**10, dtype="int32"), 2, target_bits=32, target_dtype="int32")
check(const(2**32, dtype="int32"), 2, target_bits=32, target_dtype="int32")
# i64 -> i32
check(const(2**10, dtype="int64"), 2, target_bits=32, target_dtype="int32")
check(const(2**32, dtype="int64"), 2, target_bits=32, target_dtype="int64")
Expand Down