From d28859d2838493ad816271c4d2a8f7d1c4227ac5 Mon Sep 17 00:00:00 2001 From: hgt312 Date: Thu, 17 Jun 2021 00:36:59 +0800 Subject: [PATCH 1/4] fix --- python/tvm/tir/expr.py | 2 +- src/tir/analysis/deep_equal.cc | 3 +++ tests/python/unittest/test_arith_rewrite_simplify.py | 4 ++++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 286e4051da51..82e40d01412f 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -1190,7 +1190,7 @@ def __init__(self, var, value, body, span=None): @tvm._ffi.register_object("tir.Any") -class Any(PrimExpr): +class Any(PrimExprWithOp): """Any node. span : Optional[Span] diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc index 7eb8013f2a85..19a632e1412b 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tir/analysis/deep_equal.cc @@ -59,6 +59,9 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { auto* prhs = rhs.as(); return plhs->dtype == prhs->dtype && plhs->value == prhs->value; } + if (auto* plhs = lhs.as()) { + return lhs.same_as(rhs); + } return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false); } diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index c3afa6c65627..b93c5c1aa92a 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -275,6 +275,7 @@ def test_add_index_simplify(): def test_sub_index_simplify(): ck = RewriteChecker() x, y, z = te.var("x"), te.var("y"), te.var("z") + a, b, c = tvm.tir.Any(), tvm.tir.Any(), tvm.tir.Any() ck.verify(x + y - y, x) ck.verify(x + y - x, y) @@ -293,6 +294,9 @@ def test_sub_index_simplify(): # mul co-efficient foldng ck.verify(x - x, 0) + ck.verify(a - a, 0) + ck.verify(a - b, a - b) + ck.verify(a - b, c - c) ck.verify(x * y - x, x * (y + (-1))) ck.verify(x * y - 10 * x, x * (y + (-10))) ck.verify(y * x - x * z, x * (y - z)) From 60fb0d389a04a44cc4556470988a1d8f0dea080e Mon Sep 17 00:00:00 2001 From: hgt312 Date: Fri, 18 Jun 2021 19:38:37 +0800 Subject: [PATCH 2/4] fix lint --- src/tir/analysis/deep_equal.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc index 19a632e1412b..87283451f910 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tir/analysis/deep_equal.cc @@ -59,7 +59,7 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { auto* prhs = rhs.as(); return plhs->dtype == prhs->dtype && plhs->value == prhs->value; } - if (auto* plhs = lhs.as()) { + if (lhs.as()) { return lhs.same_as(rhs); } return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false); From 6cc86c88561323e42dc9a453637220da755f9f14 Mon Sep 17 00:00:00 2001 From: hgt312 Date: Mon, 21 Jun 2021 16:40:50 +0800 Subject: [PATCH 3/4] remove --- tests/python/unittest/test_arith_rewrite_simplify.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index b93c5c1aa92a..1ffd4ebdf26a 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -296,7 +296,6 @@ def test_sub_index_simplify(): ck.verify(x - x, 0) ck.verify(a - a, 0) ck.verify(a - b, a - b) - ck.verify(a - b, c - c) ck.verify(x * y - x, x * (y + (-1))) ck.verify(x * y - 10 * x, x * (y + (-10))) ck.verify(y * x - x * z, x * (y - z)) From b720a8290839daf422a212cb607112cd22622c6e Mon Sep 17 00:00:00 2001 From: hgt312 Date: Tue, 29 Jun 2021 12:19:27 +0800 Subject: [PATCH 4/4] address comments --- src/tir/analysis/deep_equal.cc | 2 +- tests/python/unittest/test_arith_rewrite_simplify.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc index 87283451f910..7f48cc439234 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tir/analysis/deep_equal.cc @@ -60,7 +60,7 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { return plhs->dtype == prhs->dtype && plhs->value == prhs->value; } if (lhs.as()) { - return lhs.same_as(rhs); + return false; } return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false); } diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 1ffd4ebdf26a..231c376c50ca 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -275,7 +275,7 @@ def test_add_index_simplify(): def test_sub_index_simplify(): ck = RewriteChecker() x, y, z = te.var("x"), te.var("y"), te.var("z") - a, b, c = tvm.tir.Any(), tvm.tir.Any(), tvm.tir.Any() + a, b = tvm.tir.Any(), tvm.tir.Any() ck.verify(x + y - y, x) ck.verify(x + y - x, y)