diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 7575037ebc11..750e78909c7a 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -152,6 +152,9 @@ class EqualOp(NodeGeneric, ExprOp): b : Expr Right operand. """ + # This class is not manipulated by C++. So use python's identity check function is sufficient + same_as = object.__eq__ + def __init__(self, a, b): self.a = a self.b = b @@ -181,6 +184,9 @@ class NotEqualOp(NodeGeneric, ExprOp): b : Expr Right operand. """ + # This class is not manipulated by C++. So use python's identity check function is sufficient + same_as = object.__eq__ + def __init__(self, a, b): self.a = a self.b = b diff --git a/tests/python/unittest/test_lang_basic.py b/tests/python/unittest/test_lang_basic.py index dd982313daba..1461ecec100f 100644 --- a/tests/python/unittest/test_lang_basic.py +++ b/tests/python/unittest/test_lang_basic.py @@ -134,6 +134,14 @@ def test_bitwise(): assert str(~x) == 'bitwise_not(x)' +def test_equality(): + a = tvm.var('a') + b = tvm.var('b') + c = (a == b) + assert not c + d = (c != c) + assert not d + if __name__ == "__main__": test_cast() test_attr() @@ -148,3 +156,4 @@ def test_bitwise(): test_any() test_all() test_bitwise() + test_equality()