Skip to content

Commit

Permalink
[PYTHON] Make IntImm more like an integer (#5232)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Apr 3, 2020
1 parent 7de8a53 commit 9b274cb
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
16 changes: 16 additions & 0 deletions python/tvm/tir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ def __init__(self, dtype, value):
self.__init_handle_by_constructor__(
tvm.ir._ffi_api.FloatImm, dtype, value)


@tvm._ffi.register_object
class IntImm(ConstExpr):
"""Int constant.
Expand All @@ -455,9 +456,24 @@ def __init__(self, dtype, value):
self.__init_handle_by_constructor__(
tvm.ir._ffi_api.IntImm, dtype, value)

def __hash__(self):
return self.value

def __int__(self):
return self.value

def __nonzero__(self):
return self.value != 0

def __eq__(self, other):
return _ffi_api._OpEQ(self, other)

def __ne__(self, other):
return _ffi_api._OpNE(self, other)

def __bool__(self):
return self.__nonzero__()


@tvm._ffi.register_object
class StringImm(ConstExpr):
Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_tir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,21 @@ def test_buffer_load_store():
assert isinstance(s, tvm.tir.BufferStore)


def test_intimm_cond():
x = tvm.runtime.convert(1)
y = tvm.runtime.convert(1)
s = {x}
assert y in s
assert x == y
assert x < 20
assert not (x >= 20)
assert x < 10 and y < 10
assert not tvm.runtime.convert(x != 1)
assert x == 1


if __name__ == "__main__":
test_intimm_cond()
test_buffer_load_store()
test_vars()
test_prim_func()
Expand Down

0 comments on commit 9b274cb

Please sign in to comment.