From 175fbc26c4e02ecba680072b2e1a5687a7fa7bbc Mon Sep 17 00:00:00 2001 From: "Huang, Guangtai" Date: Sat, 17 Jul 2021 14:56:18 +0800 Subject: [PATCH] [Bugfix] [tir] do not simplify 'Any() - Any()' to 0 (#8266) * fix * fix lint * remove * address comments --- python/tvm/tir/expr.py | 2 +- src/tir/analysis/deep_equal.cc | 3 +++ tests/python/unittest/test_arith_rewrite_simplify.py | 3 +++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 6e86157606be..4ba8c5471b5d 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -1214,7 +1214,7 @@ def __init__(self, var, value, body, span=None): @tvm._ffi.register_object("tir.Any") -class Any(PrimExpr): +class Any(PrimExprWithOp): """Any node. span : Optional[Span] diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc index 7eb8013f2a85..7f48cc439234 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tir/analysis/deep_equal.cc @@ -59,6 +59,9 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { auto* prhs = rhs.as(); return plhs->dtype == prhs->dtype && plhs->value == prhs->value; } + if (lhs.as()) { + return false; + } return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false); } diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index c3afa6c65627..231c376c50ca 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -275,6 +275,7 @@ def test_add_index_simplify(): def test_sub_index_simplify(): ck = RewriteChecker() x, y, z = te.var("x"), te.var("y"), te.var("z") + a, b = tvm.tir.Any(), tvm.tir.Any() ck.verify(x + y - y, x) ck.verify(x + y - x, y) @@ -293,6 +294,8 @@ def test_sub_index_simplify(): # mul co-efficient foldng ck.verify(x - x, 0) + ck.verify(a - a, 0) + ck.verify(a - b, a - b) ck.verify(x * y - x, x * (y + (-1))) ck.verify(x * y - 10 * x, x * (y + (-10))) ck.verify(y * x - x * z, x * (y - z))