From 4240d54418e38c8254109e7c3f8b9ebd85ede6c0 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Sat, 20 Aug 2022 22:48:28 +0800 Subject: [PATCH 1/9] Add more strict check in tir imm construction and folding. --- include/tvm/tir/op.h | 4 +- src/arith/const_fold.h | 87 ++++++++-- src/ir/expr.cc | 15 +- .../unittest/test_arith_rewrite_simplify.py | 2 + tests/python/unittest/test_tir_imm_values.py | 155 ++++++++++++++++++ .../test_tir_transform_narrow_datatype.py | 9 - 6 files changed, 249 insertions(+), 23 deletions(-) create mode 100644 tests/python/unittest/test_tir_imm_values.py diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 94603307a7f0..499bfd590c42 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -911,7 +911,9 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) if (t.is_uint()) { // Use IntImm if it is a small integer uint64_t uval = static_cast(value); - if (uval <= static_cast(std::numeric_limits::max())) { + if (value < 0) { + LOG(FATAL) << "cannot make uint from negative value " << value; + } else if (uval <= static_cast(std::numeric_limits::max())) { return IntImm(t, static_cast(value), span); } else { uint64_t mask = (static_cast(1) << 32U) - 1U; diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index 9c3afe41b901..51799278a6db 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -73,6 +73,19 @@ inline bool IsIndexType(const DataType& type) { return type.is_int() && type.lanes() == 1 && (type.bits() == 32 || type.bits() == 64); } +/*! \brief Helper to get const folding result repr in int64. 
*/ +inline int64_t GetInt64FoldResultRepr(int64_t x, const DataType& dtype) { + if (dtype.bits() < 64) { + x &= (1LL << dtype.bits()) - 1; + } + if (dtype.is_int()) { + // get sign extended value of integer with specified bits + int64_t m = 1LL << (dtype.bits() - 1); + x = (x ^ m) - m; + } + return x; +} + #define TVM_ARITH_CONST_PROPAGATION(BODY) \ using tir::FloatImmNode; \ const IntImmNode* pa = a.as(); \ @@ -95,10 +108,21 @@ template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, pa->value + pb->value); + if (pa && pb) { + int64_t res = pa->value + pb->value; + return IntImm(rtype, GetInt64FoldResultRepr(res, rtype)); + } if (pa && pa->value == 0) return b; if (pb && pb->value == 0) return a; - if (fa && fb) return FloatImm(rtype, fa->value + fb->value); + if (fa && fb) { + if (rtype.bits() == 32) { + return FloatImm(rtype, static_cast(fa->value) + static_cast(fb->value)); + } else if (rtype.bits() == 64) { + return FloatImm(rtype, fa->value + fb->value); + } else { + return PrimExpr(); + } + } if (fa && fa->value == 0) return b; if (fb && fb->value == 0) return a; }); @@ -113,9 +137,20 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { << "Checked failed. 
Minuend 's value is 0U and it's dtype is uint " << "while Subtrahend's dtype is uint; which will cause a negative uint"; const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, pa->value - pb->value); + if (pa && pb) { + int64_t res = pa->value - pb->value; + return IntImm(rtype, GetInt64FoldResultRepr(res, rtype)); + } if (pb && pb->value == 0) return a; - if (fa && fb) return FloatImm(rtype, fa->value - fb->value); + if (fa && fb) { + if (rtype.bits() == 32) { + return FloatImm(rtype, static_cast(fa->value) - static_cast(fb->value)); + } else if (rtype.bits() == 64) { + return FloatImm(rtype, fa->value - fb->value); + } else { + return PrimExpr(); + } + } if (fb && fb->value == 0) return a; }); return PrimExpr(); @@ -125,7 +160,10 @@ template <> inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); - if (pa && pb) return IntImm(rtype, pa->value * pb->value); + if (pa && pb) { + int64_t res = pa->value * pb->value; + return IntImm(rtype, GetInt64FoldResultRepr(res, rtype)); + } if (pa) { if (pa->value == 1) return b; if (pa->value == 0) return a; @@ -134,7 +172,15 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { if (pb->value == 1) return a; if (pb->value == 0) return b; } - if (fa && fb) return FloatImm(rtype, fa->value * fb->value); + if (fa && fb) { + if (rtype.bits() == 32) { + return FloatImm(rtype, static_cast(fa->value) * static_cast(fb->value)); + } else if (rtype.bits() == 64) { + return FloatImm(rtype, fa->value * fb->value); + } else { + return PrimExpr(); + } + } if (fa) { if (fa->value == 1) return b; if (fa->value == 0) return a; @@ -155,7 +201,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { // due to division and mod can have different modes // NOTE: this will assumes truc div. 
ICHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImm(rtype, pa->value / pb->value); + int64_t res = pa->value / pb->value; + return IntImm(rtype, GetInt64FoldResultRepr(res, rtype)); } if (pa) { if (pa->value == 0) return a; @@ -165,7 +212,13 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { ICHECK_NE(pb->value, 0) << "Divide by zero"; } if (fa && fb && fb->value != 0) { - return FloatImm(rtype, fa->value / fb->value); + if (rtype.bits() == 32) { + return FloatImm(rtype, static_cast(fa->value) / static_cast(fb->value)); + } else if (rtype.bits() == 64) { + return FloatImm(rtype, fa->value / fb->value); + } else { + return PrimExpr(); + } } if (fa && fa->value == 0) return a; if (fb) { @@ -182,7 +235,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) { ICHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImm(rtype, pa->value % pb->value); + int64_t res = pa->value % pb->value; + return IntImm(rtype, GetInt64FoldResultRepr(res, rtype)); } if (pa) { if (pa->value == 0) return a; @@ -201,7 +255,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) { ICHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImm(rtype, arith::floordiv(pa->value, pb->value)); + int64_t res = arith::floordiv(pa->value, pb->value); + return IntImm(rtype, GetInt64FoldResultRepr(res, rtype)); } if (pa) { if (pa->value == 0) return a; @@ -211,7 +266,14 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { ICHECK_NE(pb->value, 0) << "Divide by zero"; } if (fa && fb && fb->value != 0) { - return FloatImm(rtype, std::floor(fa->value / fb->value)); + if (rtype.bits() == 32) { + return FloatImm(rtype, + std::floor(static_cast(fa->value) / static_cast(fb->value))); + } else if (rtype.bits() == 64) { + return FloatImm(rtype, std::floor(fa->value / fb->value)); + } else { + return PrimExpr(); + } } if (fa && fa->value == 0) return a; if (fb) { @@ -228,7 +290,8 @@ 
inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) { ICHECK_NE(pb->value, 0) << "Divide by zero"; - return IntImm(rtype, floormod(pa->value, pb->value)); + int64_t res = arith::floormod(pa->value, pb->value); + return IntImm(rtype, GetInt64FoldResultRepr(res, rtype)); } if (pa) { if (pa->value == 0) return a; diff --git a/src/ir/expr.cc b/src/ir/expr.cc index d3e23800d6c7..3a09b8458125 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -76,7 +76,20 @@ IntImm::IntImm(DataType dtype, int64_t value, Span span) { ICHECK(dtype.is_int() || dtype.is_uint()) << "ValueError: IntImm supports only int or uint type, but " << dtype << " was supplied."; if (dtype.is_uint()) { - ICHECK_GE(value, 0U); + ICHECK_GE(value, 0U) << "ValueError: Literal value " << value + << " is negative for unsigned integer type " << dtype; + if (dtype.bits() < 64) { + ICHECK_LT(value, 1LL << dtype.bits()) + << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; + } + } else if (dtype.bits() == 1) { + // int(1) + ICHECK(value == 0 || value == 1) << "ValueError: " << value << " exceeds range of " << dtype; + } else if (dtype.bits() < 64) { + ICHECK_GE(value, -(1LL << (dtype.bits() - 1))) + << "ValueError: Literal value " << value << " exceeds minimum of " << dtype; + ICHECK_LT(value, 1LL << (dtype.bits() - 1)) + << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; } ObjectPtr node = make_object(); node->dtype = dtype; diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 82e1372f991e..c880f90ddffe 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -951,6 +951,8 @@ def test_cast_simplify(): ck.verify(tvm.tir.Cast(dtype1, x == x), tvm.tir.const(1, dtype1)) for dtype2 in dtypes: for i in [0, 1, 2, 3]: + if i > 1 and (dtype1 == "bool" or dtype2 == "bool"): + 
continue ck.verify(tvm.tir.Cast(dtype1, tvm.tir.const(i, dtype2)), tvm.tir.const(i, dtype1)) diff --git a/tests/python/unittest/test_tir_imm_values.py b/tests/python/unittest/test_tir_imm_values.py new file mode 100644 index 000000000000..e368887acc93 --- /dev/null +++ b/tests/python/unittest/test_tir_imm_values.py @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import math +import numpy as np +import tvm +import tvm.testing +import pytest +from tvm import tir +from tvm.script import tir as T +import pytest + + +@pytest.mark.parametrize( + "dtype, literals", + [ + ["int8", [-128, 0, 127]], + ["uint8", [0, 255]], + ["int32", [-2147483648, 2147483647]], + ["uint32", [0, 4294967295]], + ["int64", [-9223372036854775808, 9223372036854775807]], + ["uint64", [0, 9223372036854775807]], + ], +) +def test_tir_make_intimm(dtype, literals): + for l in literals: + imm = tir.const(l, dtype) + assert imm.value == l, imm + + +@pytest.mark.parametrize( + "dtype, literals", + [ + ["int8", [-129, 128]], + ["uint8", [-1, 256]], + ["int32", [-2147483650, 2147483648]], + ["uint32", [-1, 4294967296]], + ["uint64", [-1, 18446744073709551616]], + ], +) +def test_tir_invalid_intimm(dtype, literals): + for l in literals: + with pytest.raises(tvm.TVMError): + tir.const(l, dtype) + + +@pytest.mark.parametrize( + "dtype, literals", + [ + [ + "int64", + { + -9223372036854775810: 9223372036854775806, + 9223372036854775808: -9223372036854775808, + }, + ], + [ + "uint64", + { + 9223372036854775807: 9223372036854775807, + 18446744073709551615: 18446744073709551615, + }, + ], + ], +) +def test_tir_large_py_int_literals(dtype, literals): + """ + For large uint value, use LargeUIntImm intrin, + For large int value exceed int64_t value range, the value is wrapped back. 
+ """ + for l in literals: + x = tir.const(l, dtype) + if isinstance(x, (tir.IntImm, tir.FloatImm)): + assert x.value == literals[l] + else: + # LargeUIntImm(low32, hi32) + assert (int(x.args[1]) << 32) + int(x.args[0]) == literals[l] + + +def test_tir_intimm_overflow(): + assert int(tir.const(127, "int8") + tir.const(1, "int8")) == -128 + assert int(tir.const(127, "int8") + tir.const(2, "int8")) == -127 + assert int(tir.const(255, "uint8") + tir.const(1, "uint8")) == 0 + assert int(tir.const(2**31 - 1, "int32") + tir.const(1, "int32")) == -(2**31) + assert int(tir.const(2**32 - 1, "uint32") + tir.const(1, "uint32")) == 0 + assert int(tir.const(2**63 - 1, "int64") + tir.const(1, "int64")) == -(2**63) + assert int(tir.const(2**32, "uint64") * tir.const(2**32, "uint64")) == 0 + + +def compare_float_value(value, expect): + if math.isfinite(value): + assert value == expect + elif math.isnan(value): + assert math.isnan(expect) + elif math.isinf(value): + assert math.isinf(expect) + + +@pytest.mark.parametrize("dtype", ["float16", "float32", "float64"]) +@pytest.mark.parametrize("literal", [3.14, np.nan, np.inf]) +def test_tir_special_floatimms(dtype, literal): + x = tir.const(literal, dtype) + compare_float_value(x.value, literal) + + +@tvm.testing.requires_llvm() +def test_tir_floatimm_overflow(): + # Behavior check: if literal value is out of dtype range, the + # object is still constructed, and eval to infinity. 
+ @T.prim_func + def imm_overflow_fp16() -> T.float16: + T.evaluate(T.ret(T.float16(65536), dtype="float16")) + + f = tvm.build(imm_overflow_fp16, target="llvm") + assert math.isinf(f()) + + @T.prim_func + def imm_overflow_fp32() -> T.float32: + T.evaluate(T.ret(T.float32(3.4028e39), dtype="float32")) + + f = tvm.build(imm_overflow_fp32, target="llvm") + assert math.isinf(f()) + + @T.prim_func + def imm_overflow_fp64() -> T.float64: + T.evaluate(T.ret(T.float64(1.7976e309), dtype="float64")) + + f = tvm.build(imm_overflow_fp64, target="llvm") + assert math.isinf(f()) + + # Behavior check: disable fp16 folding + assert float(tir.const(1.0, "float32") * tir.const(2.0, "float32")) == 2.0 + assert not isinstance(tir.const(1.0, "float16") * tir.const(2.0, "float16"), tir.FloatImm) + + # Behavior check: folding when fp32 overflow get infinity + x = np.float32(3.4028235e37) + y = np.float32(3.4028235e37) + assert math.isinf(float(tir.const(x, "float32") * tir.const(y, "float32"))) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index d66b4ef5dd5b..20818a5b326a 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -67,8 +67,6 @@ def check(m, n, target_bits, target_dtype): # const shape # i32 -> i32 check(2, 2, 32, "int32") - # i32 + i32 is not promoted to i64 even if overflow - check(2**16, 2**16, 32, "int32") # i64 -> i32 check(const(2, dtype="int64"), const(2, dtype="int64"), 32, "int32") check(const(2**16, dtype="int64"), const(2**16, dtype="int64"), 32, "int64") @@ -100,12 +98,6 @@ def check(m, n, target_bits, target_dtype): # i32 -> i32 check(2, 32, target_bits=32, target_dtype="int32") - check( - 2**30, - 32, # i32 + i32 is not promoted to i64 even in the case of overflow - target_bits=32, - target_dtype="int32", - ) # i64 -> i32 
check(const(2, dtype="int64"), const(32, dtype="int64"), target_bits=32, target_dtype="int32") check( @@ -162,7 +154,6 @@ def check(m, lanes, target_bits, target_dtype): # i32 -> i32 check(const(2**10, dtype="int32"), 2, target_bits=32, target_dtype="int32") - check(const(2**32, dtype="int32"), 2, target_bits=32, target_dtype="int32") # i64 -> i32 check(const(2**10, dtype="int64"), 2, target_bits=32, target_dtype="int32") check(const(2**32, dtype="int64"), 2, target_bits=32, target_dtype="int64") From d7c4adcaf18d2169e40ce67f99bec0bc3480df62 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Sun, 21 Aug 2022 00:39:45 +0800 Subject: [PATCH 2/9] fix bool-compare compile error --- include/tvm/tir/op.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 499bfd590c42..258d2511b768 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -911,7 +911,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) if (t.is_uint()) { // Use IntImm if it is a small integer uint64_t uval = static_cast(value); - if (value < 0) { + if (static_cast(value) < 0) { LOG(FATAL) << "cannot make uint from negative value " << value; } else if (uval <= static_cast(std::numeric_limits::max())) { return IntImm(t, static_cast(value), span); From 3cfaefdeb78704b01e2639be819141d5d3263121 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Mon, 22 Aug 2022 13:22:55 +0800 Subject: [PATCH 3/9] fix some illegal imm construction in testcases --- tests/python/relay/test_op_level4.py | 2 +- tests/python/relay/test_pass_fuse_ops.py | 2 +- tests/python/unittest/test_target_codegen_cuda.py | 7 ++++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 89de2f6a9520..a8eb7f406c37 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -512,7 +512,7 @@ def verify( # Test backwards slicing. 
verify((3, 4, 3), [-1, -1, -1], [-5, -5, -5], [-1, -1, -1], (3, 4, 3)) # Test slicing with overlarge indices. - verify((3, 4, 3), [0, 0, 0], [np.iinfo(np.int64).max] * 3, [1, 1, 1], (3, 4, 3)) + verify((3, 4, 3), [0, 0, 0], [np.iinfo(np.int32).max] * 3, [1, 1, 1], (3, 4, 3)) # Test slice mode. verify( (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index cacce5603e5f..fe662a30766c 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -777,7 +777,7 @@ def test_fuse_dynamic_squeeze_slice_take(): squeeze = relay.op.squeeze(x, axis=[0]) strided_slice = relay.op.strided_slice( - squeeze, begin=[0, 0], end=[15130, 9223372036854775807], strides=[1, 1] + squeeze, begin=[0, 0], end=[15130, 2147483647], strides=[1, 1] ) take = relay.op.take(strided_slice, take_val, axis=0) diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index 994a85095728..96b947e20655 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -1,4 +1,5 @@ # Licensed to the Apache Software Foundation (ASF) under one + # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. 
The ASF licenses this file @@ -194,13 +195,13 @@ def check_cuda(n, value, lanes): fun(a) np.testing.assert_equal(a.numpy(), np_a) - check_cuda(64, 0xAB, 4) + check_cuda(64, np.int8(0xAB), 4) check_cuda(64, 0, 4) check_cuda(64, -3, 4) - check_cuda(64, 0xAB, 3) + check_cuda(64, np.int8(0xAB), 3) check_cuda(64, 0, 3) check_cuda(64, -3, 3) - check_cuda(64, 0xAB, 2) + check_cuda(64, np.int8(0xAB), 2) check_cuda(64, 0, 2) check_cuda(64, -3, 2) From 0504ab72819637ad04a44177bfb900b94922b58e Mon Sep 17 00:00:00 2001 From: wrongtest Date: Mon, 22 Aug 2022 15:21:22 +0800 Subject: [PATCH 4/9] do not test i64 overflow behaviour because it is not consistent on cython and ctypes --- tests/python/unittest/test_tir_imm_values.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/python/unittest/test_tir_imm_values.py b/tests/python/unittest/test_tir_imm_values.py index e368887acc93..c321ff4e1702 100644 --- a/tests/python/unittest/test_tir_imm_values.py +++ b/tests/python/unittest/test_tir_imm_values.py @@ -60,13 +60,6 @@ def test_tir_invalid_intimm(dtype, literals): @pytest.mark.parametrize( "dtype, literals", [ - [ - "int64", - { - -9223372036854775810: 9223372036854775806, - 9223372036854775808: -9223372036854775808, - }, - ], [ "uint64", { @@ -79,7 +72,6 @@ def test_tir_invalid_intimm(dtype, literals): def test_tir_large_py_int_literals(dtype, literals): """ For large uint value, use LargeUIntImm intrin, - For large int value exceed int64_t value range, the value is wrapped back. 
""" for l in literals: x = tir.const(l, dtype) From 83609845cd46f65f65411f5fe6bcbc95e4e3526b Mon Sep 17 00:00:00 2001 From: wrongtest Date: Wed, 24 Aug 2022 19:38:17 +0800 Subject: [PATCH 5/9] fix float32 testcase --- tests/python/unittest/test_tir_imm_values.py | 31 +++++++++++++++----- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/tests/python/unittest/test_tir_imm_values.py b/tests/python/unittest/test_tir_imm_values.py index c321ff4e1702..b4dff940e364 100644 --- a/tests/python/unittest/test_tir_imm_values.py +++ b/tests/python/unittest/test_tir_imm_values.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. import math +import random +import sys import numpy as np import tvm import tvm.testing @@ -133,14 +135,29 @@ def imm_overflow_fp64() -> T.float64: f = tvm.build(imm_overflow_fp64, target="llvm") assert math.isinf(f()) - # Behavior check: disable fp16 folding - assert float(tir.const(1.0, "float32") * tir.const(2.0, "float32")) == 2.0 - assert not isinstance(tir.const(1.0, "float16") * tir.const(2.0, "float16"), tir.FloatImm) - # Behavior check: folding when fp32 overflow get infinity - x = np.float32(3.4028235e37) - y = np.float32(3.4028235e37) - assert math.isinf(float(tir.const(x, "float32") * tir.const(y, "float32"))) +@tvm.testing.requires_llvm() +def test_tir_floatimm_const_fold(): + # Behavior check: folding fp32 match platform f32 arithmetic + @T.prim_func + def float_imm_multiply(x: T.float32, y: T.float32) -> T.float32: + T.evaluate(T.ret(x * y, dtype="float32")) + + fmul = tvm.build(float_imm_multiply, target="llvm") + + # overflow + for x, y in [(3.14e30, 3.14e30), (-3.14e30, 3.14e30)]: + assert float(tir.const(x, "float32") * tir.const(y, "float32")) == fmul(x, y) + + seed = random.randrange(sys.maxsize) + print( + "\nThis test is intentionally non-deterministic, " + "if it fails please report it in github issue together with this seed {}\n".format(seed) + ) + np.random.seed(seed) + 
x = np.random.uniform(np.finfo("float32").min, np.finfo("float32").max) + y = np.random.uniform(np.finfo("float32").min, np.finfo("float32").max) + assert float(tir.const(x, "float32") * tir.const(y, "float32")) == fmul(x, y), f"{x} * {y}" if __name__ == "__main__": From 5e5d84da1189b7ba46301aa9f73d4e5e2005db1e Mon Sep 17 00:00:00 2001 From: wrongtest Date: Wed, 24 Aug 2022 19:46:45 +0800 Subject: [PATCH 6/9] auto-inferred dtype should be int64 when value exceeds int32 range --- python/tvm/runtime/object_generic.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index 7a55d3ef244e..55ad7b098a18 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -118,8 +118,11 @@ def _scalar_type_inference(value): # We intentionally convert the float to float32 since it's more common in DL. dtype = "float32" elif isinstance(value, int): - # We intentionally convert the python int to int32 since it's more common in DL. - dtype = "int32" + # We intentionally prefer convert the python int to int32 since it's more common in DL. + if -2147483648 <= value < 2147483648: + dtype = "int32" + else: + dtype = "int64" else: raise NotImplementedError( "Cannot automatically inference the type." 
" value={}".format(value) From bcf0ac2dfe82eab115698c24cf7a311a150a3ab2 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Thu, 25 Aug 2022 02:22:27 +0800 Subject: [PATCH 7/9] add floatimm range check for fp16 and fp32 --- include/tvm/tir/op.h | 7 +- python/tvm/runtime/object_generic.py | 9 ++- src/ir/expr.cc | 17 +++++ src/support/scalars.cc | 4 -- src/support/scalars.h | 4 ++ tests/python/unittest/test_tir_imm_values.py | 70 +++++++++++++++----- 6 files changed, 86 insertions(+), 25 deletions(-) diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 258d2511b768..df1977f13559 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -911,7 +911,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) if (t.is_uint()) { // Use IntImm if it is a small integer uint64_t uval = static_cast(value); - if (static_cast(value) < 0) { + if (value < static_cast(0)) { LOG(FATAL) << "cannot make uint from negative value " << value; } else if (uval <= static_cast(std::numeric_limits::max())) { return IntImm(t, static_cast(value), span); @@ -934,6 +934,11 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) return PrimExpr(); } +template <> +inline PrimExpr MakeConstScalar(DataType t, bool value, Span span) { + return MakeConstScalar(t, static_cast(value), span); +} + template inline PrimExpr make_const(DataType t, ValueType value, Span span) { if (t.lanes() == 1) { diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index 55ad7b098a18..05426dfb1aeb 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -115,11 +115,14 @@ def _scalar_type_inference(value): elif isinstance(value, bool): dtype = "bool" elif isinstance(value, float): - # We intentionally convert the float to float32 since it's more common in DL. - dtype = "float32" + # We intentionally prefer convert the float to float32 since it's more common in DL. 
+ if -3.40282347e38 <= value <= 3.40282347e38: + dtype = "float32" + else: + dtype = "float64" elif isinstance(value, int): # We intentionally prefer convert the python int to int32 since it's more common in DL. - if -2147483648 <= value < 2147483648: + if -2147483648 <= value <= 2147483647: dtype = "int32" else: dtype = "int64" diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 3a09b8458125..c926cc56e89a 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -33,6 +33,8 @@ #include #include +#include "../support/scalars.h" + namespace tvm { PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) {} @@ -116,6 +118,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) FloatImm::FloatImm(DataType dtype, double value, Span span) { ICHECK_EQ(dtype.lanes(), 1) << "ValueError: FloatImm can only take scalar."; + + // check range for float32 and float16 since they have specified range. + if (!std::isinf(value) && !std::isnan(value)) { + if (dtype.bits() == 32) { + ICHECK_GE(value, std::numeric_limits::lowest()) + << "ValueError: Literal value " << value << " exceeds minimum of " << dtype; + ICHECK_LE(value, std::numeric_limits::max()) + << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; + } else if (dtype.is_float16()) { + ICHECK_GE(value, -support::kMaxFloat16) + << "ValueError: Literal value " << value << " exceeds minimum of " << dtype; + ICHECK_LE(value, support::kMaxFloat16) + << "ValueError: Literal value " << value << " exceeds maximum of " << dtype; + } + } ObjectPtr node = make_object(); node->dtype = dtype; node->value = value; diff --git a/src/support/scalars.cc b/src/support/scalars.cc index 9caa7ca58915..0ab16899bae9 100644 --- a/src/support/scalars.cc +++ b/src/support/scalars.cc @@ -174,10 +174,6 @@ IntImm ValueToIntImm(int64_t value, int width) { } } -// 2^15 * (1 + 1023/1024) -// See https://en.wikipedia.org/wiki/Half-precision_floating-point_format -constexpr double kMaxFloat16 = 65504.0; - FloatImm 
ValueToFloatImm(double value, int width) { if (width == 16) { if (!std::isinf(value) && (value < -kMaxFloat16 || value > kMaxFloat16)) { diff --git a/src/support/scalars.h b/src/support/scalars.h index 60b8fc40a8de..2fdbb001d922 100644 --- a/src/support/scalars.h +++ b/src/support/scalars.h @@ -61,6 +61,10 @@ std::string FloatImmToString(const FloatImm& float_imm); IntImm ValueToIntImm(int64_t value, int width); FloatImm ValueToFloatImm(double value, int width); +// 2^15 * (1 + 1023/1024) +// See https://en.wikipedia.org/wiki/Half-precision_floating-point_format +constexpr double kMaxFloat16 = 65504.0; + } // namespace support } // namespace tvm diff --git a/tests/python/unittest/test_tir_imm_values.py b/tests/python/unittest/test_tir_imm_values.py index b4dff940e364..b18a936eeafc 100644 --- a/tests/python/unittest/test_tir_imm_values.py +++ b/tests/python/unittest/test_tir_imm_values.py @@ -103,6 +103,35 @@ def compare_float_value(value, expect): assert math.isinf(expect) +@pytest.mark.parametrize( + "dtype, literals", + [ + ["float16", [-65504.0, 3.14, 65504.0, np.inf, np.nan]], + ["bfloat16", [-3.38953139e38, 3.38953139e38, 3.14]], + ["float32", [np.finfo("float32").min, 3.14, np.finfo("float32").max, np.inf, np.nan]], + ["float64", [np.finfo("float64").min, 3.14, np.finfo("float64").max, np.inf, np.nan]], + ], +) +def test_tir_make_floatimm(dtype, literals): + for l in literals: + imm = tir.const(l, dtype) + compare_float_value(imm.value, l) + + +@pytest.mark.parametrize( + "dtype, literals", + [ + ["float16", [-65505.0, 65505.0]], + ["float32", [-3.402e39, 3.402e39]], + ], +) +def test_tir_invalid_floatimm(dtype, literals): + """Currently only fp16 and fp32 have range check.""" + for l in literals: + with pytest.raises(tvm.TVMError): + tir.const(l, dtype) + + @pytest.mark.parametrize("dtype", ["float16", "float32", "float64"]) @pytest.mark.parametrize("literal", [3.14, np.nan, np.inf]) def test_tir_special_floatimms(dtype, literal): @@ -111,23 +140,9 @@ def 
test_tir_special_floatimms(dtype, literal): @tvm.testing.requires_llvm() -def test_tir_floatimm_overflow(): - # Behavior check: if literal value is out of dtype range, the +def test_tir_too_large_literal_f64(): + # Behavior check: if literal f64 value is out of dtype range, the # object is still constructed, and eval to infinity. - @T.prim_func - def imm_overflow_fp16() -> T.float16: - T.evaluate(T.ret(T.float16(65536), dtype="float16")) - - f = tvm.build(imm_overflow_fp16, target="llvm") - assert math.isinf(f()) - - @T.prim_func - def imm_overflow_fp32() -> T.float32: - T.evaluate(T.ret(T.float32(3.4028e39), dtype="float32")) - - f = tvm.build(imm_overflow_fp32, target="llvm") - assert math.isinf(f()) - @T.prim_func def imm_overflow_fp64() -> T.float64: T.evaluate(T.ret(T.float64(1.7976e309), dtype="float64")) @@ -136,6 +151,27 @@ def imm_overflow_fp64() -> T.float64: assert math.isinf(f()) +@pytest.mark.parametrize( + "literal, expect_dtype", + [ + (256, "int32"), + (2147483647, "int32"), + (-2147483648, "int32"), + (2147483648, "int64"), + (-2147483649, "int64"), + (3.14159, "float32"), + (np.finfo("float32").min, "float32"), + (np.finfo("float32").max, "float32"), + (-3.402e39, "float64"), + (3.402e39, "float64"), + ], +) +def test_tir_const_auto_dtype(literal, expect_dtype): + x = tir.const(literal, dtype=None) + assert x.dtype == expect_dtype + assert x.value == literal + + @tvm.testing.requires_llvm() def test_tir_floatimm_const_fold(): # Behavior check: folding fp32 match platform f32 arithmetic @@ -149,7 +185,7 @@ def float_imm_multiply(x: T.float32, y: T.float32) -> T.float32: for x, y in [(3.14e30, 3.14e30), (-3.14e30, 3.14e30)]: assert float(tir.const(x, "float32") * tir.const(y, "float32")) == fmul(x, y) - seed = random.randrange(sys.maxsize) + seed = random.randint(0, 2147483648) print( "\nThis test is intentionally non-deterministic, " "if it fails please report it in github issue together with this seed {}\n".format(seed) From 
59f7db445c1ac7a137fc7f215263d05e04d6600c Mon Sep 17 00:00:00 2001 From: wrongtest Date: Sun, 28 Aug 2022 21:23:59 +0800 Subject: [PATCH 8/9] add more folding testcases and fix store fp32 folding result to double --- python/tvm/script/tir/intrin.py | 5 + src/arith/const_fold.h | 47 ++- tests/python/unittest/test_tir_imm_values.py | 402 ++++++++++++++++++- 3 files changed, 420 insertions(+), 34 deletions(-) diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py index 382431c2296a..73a14ead149b 100644 --- a/python/tvm/script/tir/intrin.py +++ b/python/tvm/script/tir/intrin.py @@ -85,6 +85,11 @@ def truncmod(x, y, span): return tvm.tir.truncmod(x, y, span) +@register +def truncdiv(x, y, span): + return tvm.tir.truncdiv(x, y, span) + + @register def ceildiv(x, y, span): return tvm.tir.ceildiv(x, y, span) diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index 51799278a6db..ea6192724bee 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -29,6 +29,7 @@ #include #include +#include #include "int_operator.h" @@ -74,7 +75,7 @@ inline bool IsIndexType(const DataType& type) { } /*! \brief Helper to get const folding result repr in int64. */ -inline int64_t GetInt64FoldResultRepr(int64_t x, const DataType& dtype) { +inline int64_t GetFoldResultInt64Repr(int64_t x, const DataType& dtype) { if (dtype.bits() < 64) { x &= (1LL << dtype.bits()) - 1; } @@ -86,6 +87,20 @@ inline int64_t GetInt64FoldResultRepr(int64_t x, const DataType& dtype) { return x; } +/*! \brief Helper to get fp32 const folding result repr in double. 
*/ +inline double GetFoldResultDoubleRepr(float x) { + double res = static_cast(x); + if (std::isinf(res) || std::isnan(res)) { + return res; + } + if (res < std::numeric_limits::lowest()) { + return -std::numeric_limits::infinity(); + } else if (res > std::numeric_limits::max()) { + return std::numeric_limits::infinity(); + } + return res; +} + #define TVM_ARITH_CONST_PROPAGATION(BODY) \ using tir::FloatImmNode; \ const IntImmNode* pa = a.as(); \ @@ -110,13 +125,14 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) { int64_t res = pa->value + pb->value; - return IntImm(rtype, GetInt64FoldResultRepr(res, rtype)); + return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); } if (pa && pa->value == 0) return b; if (pb && pb->value == 0) return a; if (fa && fb) { if (rtype.bits() == 32) { - return FloatImm(rtype, static_cast(fa->value) + static_cast(fb->value)); + return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast(fa->value) + + static_cast(fb->value))); } else if (rtype.bits() == 64) { return FloatImm(rtype, fa->value + fb->value); } else { @@ -139,12 +155,13 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) { int64_t res = pa->value - pb->value; - return IntImm(rtype, GetInt64FoldResultRepr(res, rtype)); + return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); } if (pb && pb->value == 0) return a; if (fa && fb) { if (rtype.bits() == 32) { - return FloatImm(rtype, static_cast(fa->value) - static_cast(fb->value)); + return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast(fa->value) - + static_cast(fb->value))); } else if (rtype.bits() == 64) { return FloatImm(rtype, fa->value - fb->value); } else { @@ -162,7 +179,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { const DataType& rtype = a.dtype(); if (pa && pb) { int64_t res = pa->value * pb->value; - return IntImm(rtype, GetInt64FoldResultRepr(res, rtype)); + return IntImm(rtype, 
GetFoldResultInt64Repr(res, rtype)); } if (pa) { if (pa->value == 1) return b; @@ -174,7 +191,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } if (fa && fb) { if (rtype.bits() == 32) { - return FloatImm(rtype, static_cast(fa->value) * static_cast(fb->value)); + return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast(fa->value) * + static_cast(fb->value))); } else if (rtype.bits() == 64) { return FloatImm(rtype, fa->value * fb->value); } else { @@ -202,7 +220,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { // NOTE: this will assumes truc div. ICHECK_NE(pb->value, 0) << "Divide by zero"; int64_t res = pa->value / pb->value; - return IntImm(rtype, GetInt64FoldResultRepr(res, rtype)); + return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); } if (pa) { if (pa->value == 0) return a; @@ -213,7 +231,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } if (fa && fb && fb->value != 0) { if (rtype.bits() == 32) { - return FloatImm(rtype, static_cast(fa->value) / static_cast(fb->value)); + return FloatImm(rtype, GetFoldResultDoubleRepr(static_cast(fa->value) / + static_cast(fb->value))); } else if (rtype.bits() == 64) { return FloatImm(rtype, fa->value / fb->value); } else { @@ -236,7 +255,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { if (pa && pb) { ICHECK_NE(pb->value, 0) << "Divide by zero"; int64_t res = pa->value % pb->value; - return IntImm(rtype, GetInt64FoldResultRepr(res, rtype)); + return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); } if (pa) { if (pa->value == 0) return a; @@ -256,7 +275,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { if (pa && pb) { ICHECK_NE(pb->value, 0) << "Divide by zero"; int64_t res = arith::floordiv(pa->value, pb->value); - return IntImm(rtype, GetInt64FoldResultRepr(res, rtype)); + return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); } if (pa) { if (pa->value == 0) return a; @@ -267,8 +286,8 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { } if (fa && fb 
&& fb->value != 0) { if (rtype.bits() == 32) { - return FloatImm(rtype, - std::floor(static_cast(fa->value) / static_cast(fb->value))); + return FloatImm(rtype, GetFoldResultDoubleRepr(std::floor(static_cast(fa->value) / + static_cast(fb->value)))); } else if (rtype.bits() == 64) { return FloatImm(rtype, std::floor(fa->value / fb->value)); } else { @@ -291,7 +310,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { if (pa && pb) { ICHECK_NE(pb->value, 0) << "Divide by zero"; int64_t res = arith::floormod(pa->value, pb->value); - return IntImm(rtype, GetInt64FoldResultRepr(res, rtype)); + return IntImm(rtype, GetFoldResultInt64Repr(res, rtype)); } if (pa) { if (pa->value == 0) return a; diff --git a/tests/python/unittest/test_tir_imm_values.py b/tests/python/unittest/test_tir_imm_values.py index b18a936eeafc..c27832d8c081 100644 --- a/tests/python/unittest/test_tir_imm_values.py +++ b/tests/python/unittest/test_tir_imm_values.py @@ -16,7 +16,6 @@ # under the License. import math import random -import sys import numpy as np import tvm import tvm.testing @@ -85,22 +84,23 @@ def test_tir_large_py_int_literals(dtype, literals): def test_tir_intimm_overflow(): - assert int(tir.const(127, "int8") + tir.const(1, "int8")) == -128 - assert int(tir.const(127, "int8") + tir.const(2, "int8")) == -127 assert int(tir.const(255, "uint8") + tir.const(1, "uint8")) == 0 assert int(tir.const(2**31 - 1, "int32") + tir.const(1, "int32")) == -(2**31) assert int(tir.const(2**32 - 1, "uint32") + tir.const(1, "uint32")) == 0 assert int(tir.const(2**63 - 1, "int64") + tir.const(1, "int64")) == -(2**63) assert int(tir.const(2**32, "uint64") * tir.const(2**32, "uint64")) == 0 + # customized int types + assert int(tir.const(7, "int4") + tir.const(1, "int4")) == -8 + assert int(tir.const(2**39 - 1, "int40") + tir.const(1, "int40")) == -(2**39) -def compare_float_value(value, expect): +def compare_float_value(value, expect, msg): if math.isfinite(value): - assert value == expect + assert 
value == expect, f"{value} vs {expect}, {msg}" elif math.isnan(value): - assert math.isnan(expect) + assert math.isnan(expect), f"{value} vs {expect}, {msg}" elif math.isinf(value): - assert math.isinf(expect) + assert math.isinf(expect), f"{value} vs {expect}, {msg}" @pytest.mark.parametrize( @@ -115,7 +115,7 @@ def compare_float_value(value, expect): def test_tir_make_floatimm(dtype, literals): for l in literals: imm = tir.const(l, dtype) - compare_float_value(imm.value, l) + compare_float_value(imm.value, l, "imm value should match feed value") @pytest.mark.parametrize( @@ -136,7 +136,7 @@ def test_tir_invalid_floatimm(dtype, literals): @pytest.mark.parametrize("literal", [3.14, np.nan, np.inf]) def test_tir_special_floatimms(dtype, literal): x = tir.const(literal, dtype) - compare_float_value(x.value, literal) + compare_float_value(x.value, literal, "imm value should match feed value") @tvm.testing.requires_llvm() @@ -172,28 +172,390 @@ def test_tir_const_auto_dtype(literal, expect_dtype): assert x.value == literal +def check_tir_const_fold( + dtype, foldf, calcf, x_range=None, y_range=None, expect=None, skip_overflow=False +): + """Helper to check constant folding behavior + + Parameters + ---------- + dtype: str + Datatype of constants + + foldf: (x, y) -> z + Folding function to call + + calcf: (x, y) -> z + Compiled calculation function to call + + x_range: Union[int, float, tuple] + Single value or value range [min, max] + + y_range: Union[int, float, tuple] + Single value or value range [min, max] + + expect: Union[int, float] + Expected calculation result + + skip_overflow: bool + Skip assertion if the overflow happens + """ + seed = random.randint(0, 2147483648) + np.random.seed(seed) + ninfo = np.finfo(dtype) if dtype.startswith("float") else np.iinfo(dtype) + + if x_range is None: + x_range = (ninfo.min, ninfo.max) + if isinstance(x_range, (int, float)): + x = x_range + elif dtype.startswith("int") or dtype.startswith("uint"): + x = 
np.random.randint(x_range[0], x_range[1] + 1) + else: + x = np.random.uniform(x_range[0], x_range[1]) + + if y_range is None: + y_range = (ninfo.min, ninfo.max) + if isinstance(y_range, (int, float)): + y = y_range + elif dtype.startswith("int") or dtype.startswith("uint"): + y = np.random.randint(y_range[0], y_range[1] + 1) + else: + y = np.random.uniform(y_range[0], y_range[1]) + + if skip_overflow: + py_res = foldf(x, y) + if isinstance(py_res, (tir.IntImm, tir.FloatImm)): + py_res = py_res.value + if not (ninfo.min <= py_res <= ninfo.max): + # If the result overflow, certain arithmetics is non-defined + # thus we intentionally do not make the test failed. + return + + fold_res = foldf(tir.const(x, dtype), tir.const(y, dtype)) + calc_res = calcf(x, y) + + flaky_msg = ( + f"{dtype} ({x}, {y}, {expect}) const folding check failed.\n" + + "This test is intentionally non-deterministic, " + + f"if it fails please report it in github issue together with this seed {seed}\n" + ) + compare_float_value(calc_res, fold_res.value, flaky_msg) + if expect: + compare_float_value(expect, calc_res, flaky_msg) + + @tvm.testing.requires_llvm() def test_tir_floatimm_const_fold(): - # Behavior check: folding fp32 match platform f32 arithmetic + """Behavior check: folding fp32 match platform f32 arithmetic""" + @T.prim_func def float_imm_multiply(x: T.float32, y: T.float32) -> T.float32: T.evaluate(T.ret(x * y, dtype="float32")) + @T.prim_func + def float_imm_add(x: T.float32, y: T.float32) -> T.float32: + T.evaluate(T.ret(x + y, dtype="float32")) + + @T.prim_func + def float_imm_sub(x: T.float32, y: T.float32) -> T.float32: + T.evaluate(T.ret(x - y, dtype="float32")) + + @T.prim_func + def float_imm_div(x: T.float32, y: T.float32) -> T.float32: + T.evaluate(T.ret(x / y, dtype="float32")) + fmul = tvm.build(float_imm_multiply, target="llvm") + fadd = tvm.build(float_imm_add, target="llvm") + fsub = tvm.build(float_imm_sub, target="llvm") + fdiv = tvm.build(float_imm_div, 
target="llvm") # overflow - for x, y in [(3.14e30, 3.14e30), (-3.14e30, 3.14e30)]: - assert float(tir.const(x, "float32") * tir.const(y, "float32")) == fmul(x, y) + check_tir_const_fold("float32", lambda x, y: x * y, fmul, 3.0e30, 3.0e30, np.inf) + check_tir_const_fold("float32", lambda x, y: x * y, fmul, 3.0e30, -3.0e30, -np.inf) + check_tir_const_fold("float32", lambda x, y: x / y, fdiv, 3.0e30, 3.0e-30, np.inf) - seed = random.randint(0, 2147483648) - print( - "\nThis test is intentionally non-deterministic, " - "if it fails please report it in github issue together with this seed {}\n".format(seed) + # divide by zero + with pytest.raises(tvm.TVMError): + check_tir_const_fold("float32", lambda x, y: x / y, fdiv, 1.0, 0.0) + + # nan and inf + check_tir_const_fold("float32", lambda x, y: x + y, fadd, 1.0, np.nan, np.nan) + check_tir_const_fold("float32", lambda x, y: x + y, fadd, 1.0, np.inf, np.inf) + check_tir_const_fold("float32", lambda x, y: x + y, fadd, 1.0, -np.inf, -np.inf) + + # randomized check + check_tir_const_fold("float32", lambda x, y: x * y, fmul) + check_tir_const_fold("float32", lambda x, y: x + y, fadd) + check_tir_const_fold("float32", lambda x, y: x - y, fsub) + check_tir_const_fold( + "float32", lambda x, y: x / y, fdiv, y_range=(0.01, np.finfo("float32").max) + ) + + +@tvm.testing.requires_llvm() +def test_tir_int8_const_fold(): + """Behavior check: folding i8 operation match platform i8 arithmetic""" + + @T.prim_func + def imm_multiply(x: T.int8, y: T.int8) -> T.int8: + T.evaluate(T.ret(x * y, dtype="int8")) + + @T.prim_func + def imm_add(x: T.int8, y: T.int8) -> T.int8: + T.evaluate(T.ret(x + y, dtype="int8")) + + @T.prim_func + def imm_sub(x: T.int8, y: T.int8) -> T.int8: + T.evaluate(T.ret(x - y, dtype="int8")) + + @T.prim_func + def imm_truncdiv(x: T.int8, y: T.int8) -> T.int8: + T.evaluate(T.ret(T.truncdiv(x, y), dtype="int8")) + + @T.prim_func + def imm_floordiv(x: T.int8, y: T.int8) -> T.int8: + T.evaluate(T.ret(T.floordiv(x, y), 
dtype="int8")) + + fmul = tvm.build(imm_multiply, target="llvm") + fadd = tvm.build(imm_add, target="llvm") + fsub = tvm.build(imm_sub, target="llvm") + ffloordiv = tvm.build(imm_floordiv, target="llvm") + ftruncdiv = tvm.build(imm_truncdiv, target="llvm") + + # overflow + check_tir_const_fold("int8", lambda x, y: x + y, fadd, 127, 1, -128) + check_tir_const_fold("int8", lambda x, y: x * y, fmul, 127, 127, 1) + + # divide by zero + with pytest.raises(tvm.TVMError): + check_tir_const_fold("int8", lambda x, y: tir.floordiv(x, y), ffloordiv, 1, 0) + with pytest.raises(tvm.TVMError): + check_tir_const_fold("int8", lambda x, y: tir.truncdiv(x, y), ftruncdiv, 1, 0) + + # i8 mod folding is not implemented + assert not isinstance(tir.floormod(tir.const(7, "int8"), tir.const(3, "int8")), tir.IntImm) + assert not isinstance(tir.truncmod(tir.const(7, "int8"), tir.const(3, "int8")), tir.IntImm) + + # randomized check + check_tir_const_fold("int8", lambda x, y: x * y, fmul) + check_tir_const_fold("int8", lambda x, y: x + y, fadd) + check_tir_const_fold("int8", lambda x, y: x - y, fsub) + check_tir_const_fold( + "int8", lambda x, y: tir.floordiv(x, y), ffloordiv, y_range=(1, np.iinfo("int8").max) + ) + check_tir_const_fold( + "int8", lambda x, y: tir.truncdiv(x, y), ftruncdiv, y_range=(1, np.iinfo("int8").max) + ) + + +@tvm.testing.requires_llvm() +def test_tir_uint8_const_fold(): + """Behavior check: folding u8 operation match platform u8 arithmetic""" + + @T.prim_func + def imm_multiply(x: T.uint8, y: T.uint8) -> T.uint8: + T.evaluate(T.ret(x * y, dtype="uint8")) + + @T.prim_func + def imm_add(x: T.uint8, y: T.uint8) -> T.uint8: + T.evaluate(T.ret(x + y, dtype="uint8")) + + @T.prim_func + def imm_sub(x: T.uint8, y: T.uint8) -> T.uint8: + T.evaluate(T.ret(x - y, dtype="uint8")) + + @T.prim_func + def imm_truncdiv(x: T.uint8, y: T.uint8) -> T.uint8: + T.evaluate(T.ret(T.truncdiv(x, y), dtype="uint8")) + + @T.prim_func + def imm_floordiv(x: T.uint8, y: T.uint8) -> T.uint8: + 
T.evaluate(T.ret(T.floordiv(x, y), dtype="uint8")) + + fmul = tvm.build(imm_multiply, target="llvm") + fadd = tvm.build(imm_add, target="llvm") + fsub = tvm.build(imm_sub, target="llvm") + ffloordiv = tvm.build(imm_floordiv, target="llvm") + ftruncdiv = tvm.build(imm_truncdiv, target="llvm") + + # overflow + check_tir_const_fold("uint8", lambda x, y: x + y, fadd, 255, 1, 0) + + # zero sub + with pytest.raises(tvm.TVMError): + check_tir_const_fold("uint8", lambda x, y: x - y, fsub, 0, 10) + + # divide by zero + with pytest.raises(tvm.TVMError): + check_tir_const_fold("uint8", lambda x, y: tir.floordiv(x, y), ffloordiv, 1, 0) + with pytest.raises(tvm.TVMError): + check_tir_const_fold("uint8", lambda x, y: tir.truncdiv(x, y), ftruncdiv, 1, 0) + + # u8 mod folding is not implemented + assert not isinstance(tir.floormod(tir.const(7, "uint8"), tir.const(3, "uint8")), tir.IntImm) + assert not isinstance(tir.truncmod(tir.const(7, "uint8"), tir.const(3, "uint8")), tir.IntImm) + + # randomized check + check_tir_const_fold("uint8", lambda x, y: x * y, fmul) + check_tir_const_fold("uint8", lambda x, y: x + y, fadd) + check_tir_const_fold("uint8", lambda x, y: x - y, fsub) + check_tir_const_fold( + "uint8", lambda x, y: tir.floordiv(x, y), ffloordiv, y_range=(1, np.iinfo("uint8").max) + ) + check_tir_const_fold( + "uint8", lambda x, y: tir.truncdiv(x, y), ftruncdiv, y_range=(1, np.iinfo("uint8").max) + ) + + +@tvm.testing.requires_llvm() +def test_tir_int32_const_fold(): + """Behavior check: folding i32 operation match platform i32 arithmetic""" + + @T.prim_func + def imm_multiply(x: T.int32, y: T.int32) -> T.int32: + T.evaluate(T.ret(x * y, dtype="int32")) + + @T.prim_func + def imm_add(x: T.int32, y: T.int32) -> T.int32: + T.evaluate(T.ret(x + y, dtype="int32")) + + @T.prim_func + def imm_sub(x: T.int32, y: T.int32) -> T.int32: + T.evaluate(T.ret(x - y, dtype="int32")) + + @T.prim_func + def imm_truncdiv(x: T.int32, y: T.int32) -> T.int32: + T.evaluate(T.ret(T.truncdiv(x, y), 
dtype="int32")) + + @T.prim_func + def imm_truncmod(x: T.int32, y: T.int32) -> T.int32: + T.evaluate(T.ret(T.truncmod(x, y), dtype="int32")) + + @T.prim_func + def imm_floordiv(x: T.int32, y: T.int32) -> T.int32: + T.evaluate(T.ret(T.floordiv(x, y), dtype="int32")) + + @T.prim_func + def imm_floormod(x: T.int32, y: T.int32) -> T.int32: + T.evaluate(T.ret(T.floormod(x, y), dtype="int32")) + + fmul = tvm.build(imm_multiply, target="llvm") + fadd = tvm.build(imm_add, target="llvm") + fsub = tvm.build(imm_sub, target="llvm") + ffloordiv = tvm.build(imm_floordiv, target="llvm") + ffloormod = tvm.build(imm_floormod, target="llvm") + ftruncdiv = tvm.build(imm_truncdiv, target="llvm") + ftruncmod = tvm.build(imm_truncmod, target="llvm") + + # i32 overflow is not specified, only check for range + assert -(2**31) <= int(tir.const(2**31 - 1, "int32") + tir.const(1, "int32")) < 2**31 + assert -(2**31) <= int(tir.const(-(2**31), "int32") - tir.const(1, "int32")) < 2**31 + + # divide by zero + with pytest.raises(tvm.TVMError): + check_tir_const_fold("int32", lambda x, y: tir.floordiv(x, y), ffloordiv, 1, 0) + with pytest.raises(tvm.TVMError): + check_tir_const_fold("int32", lambda x, y: tir.floormod(x, y), ffloormod, 1, 0) + with pytest.raises(tvm.TVMError): + check_tir_const_fold("int32", lambda x, y: tir.truncdiv(x, y), ftruncdiv, 1, 0) + with pytest.raises(tvm.TVMError): + check_tir_const_fold("int32", lambda x, y: tir.truncmod(x, y), ftruncmod, 1, 0) + + # randomized check + check_tir_const_fold("int32", lambda x, y: x * y, fmul, skip_overflow=True) + check_tir_const_fold("int32", lambda x, y: x + y, fadd, skip_overflow=True) + check_tir_const_fold("int32", lambda x, y: x - y, fsub, skip_overflow=True) + check_tir_const_fold( + "int32", + lambda x, y: tir.floordiv(x, y), + ffloordiv, + y_range=(1, np.iinfo("int32").max), + skip_overflow=True, + ) + check_tir_const_fold( + "int32", + lambda x, y: tir.truncdiv(x, y), + ftruncdiv, + y_range=(1, np.iinfo("int32").max), + 
skip_overflow=True, + ) + check_tir_const_fold( + "int32", + lambda x, y: tir.floormod(x, y), + ffloormod, + y_range=(1, np.iinfo("int32").max), + skip_overflow=False, + ) + check_tir_const_fold( + "int32", + lambda x, y: tir.truncmod(x, y), + ftruncmod, + y_range=(1, np.iinfo("int32").max), + skip_overflow=False, + ) + + +@tvm.testing.requires_llvm() +def test_tir_uint32_const_fold(): + """Behavior check: folding u32 operation match platform u32 arithmetic""" + + @T.prim_func + def imm_multiply(x: T.uint32, y: T.uint32) -> T.uint32: + T.evaluate(T.ret(x * y, dtype="uint32")) + + @T.prim_func + def imm_add(x: T.uint32, y: T.uint32) -> T.uint32: + T.evaluate(T.ret(x + y, dtype="uint32")) + + @T.prim_func + def imm_sub(x: T.uint32, y: T.uint32) -> T.uint32: + T.evaluate(T.ret(x - y, dtype="uint32")) + + @T.prim_func + def imm_truncdiv(x: T.uint32, y: T.uint32) -> T.uint32: + T.evaluate(T.ret(T.truncdiv(x, y), dtype="uint32")) + + @T.prim_func + def imm_floordiv(x: T.uint32, y: T.uint32) -> T.uint32: + T.evaluate(T.ret(T.floordiv(x, y), dtype="uint32")) + + fmul = tvm.build(imm_multiply, target="llvm") + fadd = tvm.build(imm_add, target="llvm") + fsub = tvm.build(imm_sub, target="llvm") + ffloordiv = tvm.build(imm_floordiv, target="llvm") + ftruncdiv = tvm.build(imm_truncdiv, target="llvm") + + # u32 overflow is not specified, only check for range + assert 0 <= int(tir.const(2**32 - 1, "uint32") + tir.const(1, "uint32")) < 2**32 + + # divide by zero + with pytest.raises(tvm.TVMError): + check_tir_const_fold("uint32", lambda x, y: tir.floordiv(x, y), ffloordiv, 1, 0) + with pytest.raises(tvm.TVMError): + check_tir_const_fold("uint32", lambda x, y: tir.truncdiv(x, y), ftruncdiv, 1, 0) + + # u32 mod folding is not implemented + assert not isinstance(tir.floormod(tir.const(7, "uint32"), tir.const(3, "uint32")), tir.IntImm) + assert not isinstance(tir.truncmod(tir.const(7, "uint32"), tir.const(3, "uint32")), tir.IntImm) + + # randomized check + 
check_tir_const_fold("uint32", lambda x, y: x * y, fmul, skip_overflow=True) + check_tir_const_fold("uint32", lambda x, y: x + y, fadd, skip_overflow=True) + check_tir_const_fold("uint32", lambda x, y: x - y, fsub, skip_overflow=True) + check_tir_const_fold( + "uint32", + lambda x, y: tir.floordiv(x, y), + ffloordiv, + y_range=(1, np.iinfo("uint32").max), + skip_overflow=False, + ) + check_tir_const_fold( + "uint32", + lambda x, y: tir.truncdiv(x, y), + ftruncdiv, + y_range=(1, np.iinfo("uint32").max), + skip_overflow=False, ) - np.random.seed(seed) - x = np.random.uniform(np.finfo("float32").min, np.finfo("float32").max) - y = np.random.uniform(np.finfo("float32").min, np.finfo("float32").max) - assert float(tir.const(x, "float32") * tir.const(y, "float32")) == fmul(x, y), f"{x} * {y}" if __name__ == "__main__": From f900f81743c6b015998d71e0871583a5c7bb859d Mon Sep 17 00:00:00 2001 From: wrongtest Date: Mon, 29 Aug 2022 13:12:14 +0800 Subject: [PATCH 9/9] fix i386 fp16 cases --- src/arith/const_fold.h | 6 +++ tests/python/unittest/test_tir_imm_values.py | 51 +++++++++++++------- 2 files changed, 39 insertions(+), 18 deletions(-) diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index ea6192724bee..d0e09a1a7429 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -93,9 +93,15 @@ inline double GetFoldResultDoubleRepr(float x) { if (std::isinf(res) || std::isnan(res)) { return res; } + // certain platform (eg, on gcc7-i386) do the folding arithmetic + // on float and write back to double is optimized to double + // precision arithmetic, this is legal and we check the output + // range thus to ensure consistency when the float result is inf. 
if (res < std::numeric_limits::lowest()) { + LOG(WARNING) << "underlying float value overflow"; return -std::numeric_limits::infinity(); } else if (res > std::numeric_limits::max()) { + LOG(WARNING) << "underlying float value overflow"; return std::numeric_limits::infinity(); } return res; diff --git a/tests/python/unittest/test_tir_imm_values.py b/tests/python/unittest/test_tir_imm_values.py index c27832d8c081..a2a19a09ad87 100644 --- a/tests/python/unittest/test_tir_imm_values.py +++ b/tests/python/unittest/test_tir_imm_values.py @@ -96,7 +96,7 @@ def test_tir_intimm_overflow(): def compare_float_value(value, expect, msg): if math.isfinite(value): - assert value == expect, f"{value} vs {expect}, {msg}" + assert np.abs(value - expect) < 1e-5, f"{value} vs {expect}, {msg}" elif math.isnan(value): assert math.isnan(expect), f"{value} vs {expect}, {msg}" elif math.isinf(value): @@ -209,7 +209,7 @@ def check_tir_const_fold( if isinstance(x_range, (int, float)): x = x_range elif dtype.startswith("int") or dtype.startswith("uint"): - x = np.random.randint(x_range[0], x_range[1] + 1) + x = np.random.randint(x_range[0], x_range[1] + 1, dtype=dtype) else: x = np.random.uniform(x_range[0], x_range[1]) @@ -218,7 +218,7 @@ def check_tir_const_fold( if isinstance(y_range, (int, float)): y = y_range elif dtype.startswith("int") or dtype.startswith("uint"): - y = np.random.randint(y_range[0], y_range[1] + 1) + y = np.random.randint(y_range[0], y_range[1] + 1, dtype=dtype) else: y = np.random.uniform(y_range[0], y_range[1]) @@ -239,9 +239,14 @@ def check_tir_const_fold( + "This test is intentionally non-deterministic, " + f"if it fails please report it in github issue together with this seed {seed}\n" ) - compare_float_value(calc_res, fold_res.value, flaky_msg) - if expect: - compare_float_value(expect, calc_res, flaky_msg) + if dtype.startswith("float"): + compare_float_value(calc_res, fold_res.value, flaky_msg) + if expect: + compare_float_value(expect, calc_res, flaky_msg) + 
else: + assert calc_res == fold_res.value, flaky_msg + if expect: + assert expect == calc_res, flaky_msg @tvm.testing.requires_llvm() @@ -249,25 +254,35 @@ def test_tir_floatimm_const_fold(): """Behavior check: folding fp32 match platform f32 arithmetic""" @T.prim_func - def float_imm_multiply(x: T.float32, y: T.float32) -> T.float32: - T.evaluate(T.ret(x * y, dtype="float32")) + def float_imm_multiply(x: T.float32, y: T.float32, z: T.Buffer[(), "float32"]): + z[()] = x * y @T.prim_func - def float_imm_add(x: T.float32, y: T.float32) -> T.float32: - T.evaluate(T.ret(x + y, dtype="float32")) + def float_imm_add(x: T.float32, y: T.float32, z: T.Buffer[(), "float32"]): + z[()] = x + y @T.prim_func - def float_imm_sub(x: T.float32, y: T.float32) -> T.float32: - T.evaluate(T.ret(x - y, dtype="float32")) + def float_imm_sub(x: T.float32, y: T.float32, z: T.Buffer[(), "float32"]): + z[()] = x - y @T.prim_func - def float_imm_div(x: T.float32, y: T.float32) -> T.float32: - T.evaluate(T.ret(x / y, dtype="float32")) + def float_imm_div(x: T.float32, y: T.float32, z: T.Buffer[(), "float32"]): + z[()] = x / y + + def __wrap_build(f): + lib = tvm.build(f, target="llvm") + z = tvm.nd.array(np.zeros([]).astype("float32")) + + def _func(x, y): + lib(x, y, z) + return z.numpy() + + return _func - fmul = tvm.build(float_imm_multiply, target="llvm") - fadd = tvm.build(float_imm_add, target="llvm") - fsub = tvm.build(float_imm_sub, target="llvm") - fdiv = tvm.build(float_imm_div, target="llvm") + fmul = __wrap_build(float_imm_multiply) + fadd = __wrap_build(float_imm_add) + fsub = __wrap_build(float_imm_sub) + fdiv = __wrap_build(float_imm_div) # overflow check_tir_const_fold("float32", lambda x, y: x * y, fmul, 3.0e30, 3.0e30, np.inf)