Skip to content

Commit

Permalink
Add more strict check in tir imm construction and folding.
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongtest-intellif committed Aug 20, 2022
1 parent 72b0f5e commit 0375ec5
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 23 deletions.
4 changes: 3 additions & 1 deletion include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,9 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
if (t.is_uint()) {
// Use IntImm if it is a small integer
uint64_t uval = static_cast<uint64_t>(value);
if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
if (value < 0) {
LOG(FATAL) << "cannot make uint from negative value " << value;
} else if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
return IntImm(t, static_cast<int64_t>(value), span);
} else {
uint64_t mask = (static_cast<uint64_t>(1) << 32U) - 1U;
Expand Down
87 changes: 75 additions & 12 deletions src/arith/const_fold.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,19 @@ inline bool IsIndexType(const DataType& type) {
return type.is_int() && type.lanes() == 1 && (type.bits() == 32 || type.bits() == 64);
}

/*! \brief Helper to get const folding result repr in int64. */
inline int64_t GetInt64FoldResultRepr(int64_t x, const DataType& dtype) {
if (dtype.bits() < 64) {
x &= (1LL << dtype.bits()) - 1;
}
if (dtype.is_int()) {
// get sign extended value of integer with specified bits
int64_t m = 1LL << (dtype.bits() - 1);
x = (x ^ m) - m;
}
return x;
}

#define TVM_ARITH_CONST_PROPAGATION(BODY) \
using tir::FloatImmNode; \
const IntImmNode* pa = a.as<IntImmNode>(); \
Expand All @@ -95,10 +108,21 @@ template <>
inline PrimExpr TryConstFold<tir::Add>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, pa->value + pb->value);
if (pa && pb) {
int64_t res = pa->value + pb->value;
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
}
if (pa && pa->value == 0) return b;
if (pb && pb->value == 0) return a;
if (fa && fb) return FloatImm(rtype, fa->value + fb->value);
if (fa && fb) {
if (rtype.bits() == 32) {
return FloatImm(rtype, static_cast<float>(fa->value) + static_cast<float>(fb->value));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value + fb->value);
} else {
return PrimExpr();
}
}
if (fa && fa->value == 0) return b;
if (fb && fb->value == 0) return a;
});
Expand All @@ -113,9 +137,20 @@ inline PrimExpr TryConstFold<tir::Sub>(PrimExpr a, PrimExpr b) {
<< "Checked failed. Minuend 's value is 0U and it's dtype is uint "
<< "while Subtrahend's dtype is uint; which will cause a negative uint";
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, pa->value - pb->value);
if (pa && pb) {
int64_t res = pa->value - pb->value;
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
}
if (pb && pb->value == 0) return a;
if (fa && fb) return FloatImm(rtype, fa->value - fb->value);
if (fa && fb) {
if (rtype.bits() == 32) {
return FloatImm(rtype, static_cast<float>(fa->value) - static_cast<float>(fb->value));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value - fb->value);
} else {
return PrimExpr();
}
}
if (fb && fb->value == 0) return a;
});
return PrimExpr();
Expand All @@ -125,7 +160,10 @@ template <>
inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, pa->value * pb->value);
if (pa && pb) {
int64_t res = pa->value * pb->value;
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
}
if (pa) {
if (pa->value == 1) return b;
if (pa->value == 0) return a;
Expand All @@ -134,7 +172,15 @@ inline PrimExpr TryConstFold<tir::Mul>(PrimExpr a, PrimExpr b) {
if (pb->value == 1) return a;
if (pb->value == 0) return b;
}
if (fa && fb) return FloatImm(rtype, fa->value * fb->value);
if (fa && fb) {
if (rtype.bits() == 32) {
return FloatImm(rtype, static_cast<float>(fa->value) * static_cast<float>(fb->value));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value * fb->value);
} else {
return PrimExpr();
}
}
if (fa) {
if (fa->value == 1) return b;
if (fa->value == 0) return a;
Expand All @@ -155,7 +201,8 @@ inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
// due to division and mod can have different modes
// NOTE: this will assumes truc div.
ICHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, pa->value / pb->value);
int64_t res = pa->value / pb->value;
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
}
if (pa) {
if (pa->value == 0) return a;
Expand All @@ -165,7 +212,13 @@ inline PrimExpr TryConstFold<tir::Div>(PrimExpr a, PrimExpr b) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
}
if (fa && fb && fb->value != 0) {
return FloatImm(rtype, fa->value / fb->value);
if (rtype.bits() == 32) {
return FloatImm(rtype, static_cast<float>(fa->value) / static_cast<float>(fb->value));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, fa->value / fb->value);
} else {
return PrimExpr();
}
}
if (fa && fa->value == 0) return a;
if (fb) {
Expand All @@ -182,7 +235,8 @@ inline PrimExpr TryConstFold<tir::Mod>(PrimExpr a, PrimExpr b) {
const DataType& rtype = a.dtype();
if (pa && pb) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, pa->value % pb->value);
int64_t res = pa->value % pb->value;
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
}
if (pa) {
if (pa->value == 0) return a;
Expand All @@ -201,7 +255,8 @@ inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
const DataType& rtype = a.dtype();
if (pa && pb) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, arith::floordiv(pa->value, pb->value));
int64_t res = arith::floordiv(pa->value, pb->value);
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
}
if (pa) {
if (pa->value == 0) return a;
Expand All @@ -211,7 +266,14 @@ inline PrimExpr TryConstFold<tir::FloorDiv>(PrimExpr a, PrimExpr b) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
}
if (fa && fb && fb->value != 0) {
return FloatImm(rtype, std::floor(fa->value / fb->value));
if (rtype.bits() == 32) {
return FloatImm(rtype,
std::floorf(static_cast<float>(fa->value) / static_cast<float>(fb->value)));
} else if (rtype.bits() == 64) {
return FloatImm(rtype, std::floor(fa->value / fb->value));
} else {
return PrimExpr();
}
}
if (fa && fa->value == 0) return a;
if (fb) {
Expand All @@ -228,7 +290,8 @@ inline PrimExpr TryConstFold<tir::FloorMod>(PrimExpr a, PrimExpr b) {
const DataType& rtype = a.dtype();
if (pa && pb) {
ICHECK_NE(pb->value, 0) << "Divide by zero";
return IntImm(rtype, floormod(pa->value, pb->value));
int64_t res = arith::floormod(pa->value, pb->value);
return IntImm(rtype, GetInt64FoldResultRepr(res, rtype));
}
if (pa) {
if (pa->value == 0) return a;
Expand Down
15 changes: 14 additions & 1 deletion src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,20 @@ IntImm::IntImm(DataType dtype, int64_t value, Span span) {
ICHECK(dtype.is_int() || dtype.is_uint())
<< "ValueError: IntImm supports only int or uint type, but " << dtype << " was supplied.";
if (dtype.is_uint()) {
ICHECK_GE(value, 0U);
ICHECK_GE(value, 0U) << "ValueError: Literal value " << value
<< " is negative for unsigned integer type " << dtype;
if (dtype.bits() < 64) {
ICHECK_LT(value, 1LL << dtype.bits())
<< "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
}
} else if (dtype.bits() == 1) {
// int(1)
ICHECK(value == 0 || value == 1) << "ValueError: " << value << " exceeds range of " << dtype;
} else if (dtype.bits() < 64) {
ICHECK_GE(value, -(1LL << (dtype.bits() - 1)))
<< "ValueError: Literal value " << value << " exceeds minimum of " << dtype;
ICHECK_LT(value, 1LL << (dtype.bits() - 1))
<< "ValueError: Literal value " << value << " exceeds maximum of " << dtype;
}
ObjectPtr<IntImmNode> node = make_object<IntImmNode>();
node->dtype = dtype;
Expand Down
2 changes: 2 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,8 @@ def test_cast_simplify():
ck.verify(tvm.tir.Cast(dtype1, x == x), tvm.tir.const(1, dtype1))
for dtype2 in dtypes:
for i in [0, 1, 2, 3]:
if i > 1 and (dtype1 == "bool" or dtype2 == "bool"):
continue
ck.verify(tvm.tir.Cast(dtype1, tvm.tir.const(i, dtype2)), tvm.tir.const(i, dtype1))


Expand Down
155 changes: 155 additions & 0 deletions tests/python/unittest/test_tir_imm_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import math
import numpy as np
import tvm
import tvm.testing
import pytest
from tvm import tir
from tvm.script import tir as T
import pytest


@pytest.mark.parametrize(
"dtype, literals",
[
["int8", [-128, 0, 127]],
["uint8", [0, 255]],
["int32", [-2147483648, 2147483647]],
["uint32", [0, 4294967295]],
["int64", [-9223372036854775808, 9223372036854775807]],
["uint64", [0, 9223372036854775807]],
],
)
def test_tir_make_intimm(dtype, literals):
for l in literals:
imm = tir.const(l, dtype)
assert imm.value == l, imm


@pytest.mark.parametrize(
"dtype, literals",
[
["int8", [-129, 128]],
["uint8", [-1, 256]],
["int32", [-2147483650, 2147483648]],
["uint32", [-1, 4294967296]],
["uint64", [-1, 18446744073709551616]],
],
)
def test_tir_invalid_intimm(dtype, literals):
for l in literals:
with pytest.raises(tvm.TVMError):
tir.const(l, dtype)


@pytest.mark.parametrize(
"dtype, literals",
[
[
"int64",
{
-9223372036854775810: 9223372036854775806,
9223372036854775808: -9223372036854775808,
},
],
[
"uint64",
{
9223372036854775807: 9223372036854775807,
18446744073709551615: 18446744073709551615,
},
],
],
)
def test_tir_large_py_int_literals(dtype, literals):
"""
For large uint value, use LargeUIntImm intrin,
For large int value exceed int64_t value range, the value is wrapped back.
"""
for l in literals:
x = tir.const(l, dtype)
if isinstance(x, (tir.IntImm, tir.FloatImm)):
assert x.value == literals[l]
else:
# LargeUIntImm(low32, hi32)
assert (int(x.args[1]) << 32) + int(x.args[0]) == literals[l]


def test_tir_intimm_overflow():
assert int(tir.const(127, "int8") + tir.const(1, "int8")) == -128
assert int(tir.const(127, "int8") + tir.const(2, "int8")) == -127
assert int(tir.const(255, "uint8") + tir.const(1, "uint8")) == 0
assert int(tir.const(2 ** 31 - 1, "int32") + tir.const(1, "int32")) == -(2 ** 31)
assert int(tir.const(2 ** 32 - 1, "uint32") + tir.const(1, "uint32")) == 0
assert int(tir.const(2 ** 63 - 1, "int64") + tir.const(1, "int64")) == -(2 ** 63)
assert int(tir.const(2 ** 32, "uint64") * tir.const(2 ** 32, "uint64")) == 0


def compare_float_value(value, expect):
if math.isfinite(value):
assert value == expect
elif math.isnan(value):
assert math.isnan(expect)
elif math.isinf(value):
assert math.isinf(expect)


@pytest.mark.parametrize("dtype", ["float16", "float32", "float64"])
@pytest.mark.parametrize("literal", [3.14, np.nan, np.inf])
def test_tir_special_floatimms(dtype, literal):
x = tir.const(literal, dtype)
compare_float_value(x.value, literal)


@tvm.testing.requires_llvm()
def test_tir_floatimm_overflow():
# Behavior check: if literal value is out of dtype range, the
# object is still constructed, and eval to infinity.
@T.prim_func
def imm_overflow_fp16() -> T.float16:
T.evaluate(T.ret(T.float16(65536), dtype="float16"))

f = tvm.build(imm_overflow_fp16, target="llvm")
assert math.isinf(f())

@T.prim_func
def imm_overflow_fp32() -> T.float32:
T.evaluate(T.ret(T.float32(3.4028e39), dtype="float32"))

f = tvm.build(imm_overflow_fp32, target="llvm")
assert math.isinf(f())

@T.prim_func
def imm_overflow_fp64() -> T.float64:
T.evaluate(T.ret(T.float64(1.7976e309), dtype="float64"))

f = tvm.build(imm_overflow_fp64, target="llvm")
assert math.isinf(f())

# Behavior check: disable fp16 folding
assert float(tir.const(1.0, "float32") * tir.const(2.0, "float32")) == 2.0
assert not isinstance(tir.const(1.0, "float16") * tir.const(2.0, "float16"), tir.FloatImm)

# Behavior check: folding when fp32 overflow get infinity
x = np.float32(3.4028235e37)
y = np.float32(3.4028235e37)
assert math.isinf(float(tir.const(x, "float32") * tir.const(y, "float32")))


if __name__ == "__main__":
tvm.testing.main()
9 changes: 0 additions & 9 deletions tests/python/unittest/test_tir_transform_narrow_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ def check(m, n, target_bits, target_dtype):
# const shape
# i32 -> i32
check(2, 2, 32, "int32")
# i32 + i32 is not promoted to i64 even if overflow
check(2**16, 2**16, 32, "int32")
# i64 -> i32
check(const(2, dtype="int64"), const(2, dtype="int64"), 32, "int32")
check(const(2**16, dtype="int64"), const(2**16, dtype="int64"), 32, "int64")
Expand Down Expand Up @@ -100,12 +98,6 @@ def check(m, n, target_bits, target_dtype):

# i32 -> i32
check(2, 32, target_bits=32, target_dtype="int32")
check(
2**30,
32, # i32 + i32 is not promoted to i64 even in the case of overflow
target_bits=32,
target_dtype="int32",
)
# i64 -> i32
check(const(2, dtype="int64"), const(32, dtype="int64"), target_bits=32, target_dtype="int32")
check(
Expand Down Expand Up @@ -162,7 +154,6 @@ def check(m, lanes, target_bits, target_dtype):

# i32 -> i32
check(const(2**10, dtype="int32"), 2, target_bits=32, target_dtype="int32")
check(const(2**32, dtype="int32"), 2, target_bits=32, target_dtype="int32")
# i64 -> i32
check(const(2**10, dtype="int64"), 2, target_bits=32, target_dtype="int32")
check(const(2**32, dtype="int64"), 2, target_bits=32, target_dtype="int64")
Expand Down

0 comments on commit 0375ec5

Please sign in to comment.