Skip to content

Commit

Permalink
[Bugfix] [tir] do not simplify 'Any() - Any()' to 0 (apache#8266)
Browse files Browse the repository at this point in the history
* fix

* fix lint

* remove

* address comments
  • Loading branch information
hgt312 authored and ylc committed Sep 29, 2021
1 parent 5de98a2 commit 175fbc2
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/tir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,7 @@ def __init__(self, var, value, body, span=None):


@tvm._ffi.register_object("tir.Any")
class Any(PrimExpr):
class Any(PrimExprWithOp):
"""Any node.
span : Optional[Span]
Expand Down
3 changes: 3 additions & 0 deletions src/tir/analysis/deep_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
auto* prhs = rhs.as<IntImmNode>();
return plhs->dtype == prhs->dtype && plhs->value == prhs->value;
}
if (lhs.as<AnyNode>()) {
return false;
}
return DeepCmpSEqualHandler().SEqualReduce(lhs, rhs, false);
}

Expand Down
3 changes: 3 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ def test_add_index_simplify():
def test_sub_index_simplify():
ck = RewriteChecker()
x, y, z = te.var("x"), te.var("y"), te.var("z")
a, b = tvm.tir.Any(), tvm.tir.Any()

ck.verify(x + y - y, x)
ck.verify(x + y - x, y)
Expand All @@ -293,6 +294,8 @@ def test_sub_index_simplify():

# mul co-efficient foldng
ck.verify(x - x, 0)
ck.verify(a - a, 0)
ck.verify(a - b, a - b)
ck.verify(x * y - x, x * (y + (-1)))
ck.verify(x * y - 10 * x, x * (y + (-10)))
ck.verify(y * x - x * z, x * (y - z))
Expand Down

0 comments on commit 175fbc2

Please sign in to comment.