Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PYTHON] Improve equal sugar #564

Merged
merged 2 commits into from
Oct 17, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 64 additions & 13 deletions python/tvm/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import as _abs
from ._ffi.node import NodeBase, register_node
from ._ffi.node import NodeBase, NodeGeneric, register_node
from . import make as _make
from . import _api_internal

Expand Down Expand Up @@ -89,10 +89,10 @@ def __le__(self, other):
return _make.LE(self, other)

def __eq__(self, other):
return self.equal(other)
return EqualOp(self, other)

def __ne__(self, other):
return _make.NE(self, other)
return NotEqualOp(self, other)

def __gt__(self, other):
return _make.GT(self, other)
Expand Down Expand Up @@ -138,12 +138,71 @@ def astype(self, dtype):
return _make.static_cast(dtype, self)


class EqualOp(NodeGeneric, ExprOp):
"""Deferred equal operator.

This is used to support sugar that a == b can either
mean NodeBase.same_as or NodeBase.equal.

Parameters
----------
a : Expr
Left operand.

b : Expr
Right operand.
"""
def __init__(self, a, b):
self.a = a
self.b = b

def __nonzero__(self):
return self.a.same_as(self.b)

def __bool__(self):
return self.__nonzero__()

def asnode(self):
"""Convert node."""
return _make.EQ(self.a, self.b)


class NotEqualOp(NodeGeneric, ExprOp):
"""Deferred NE operator.

This is used to support sugar that a != b can either
mean not NodeBase.same_as or make.NE.

Parameters
----------
a : Expr
Left operand.

b : Expr
Right operand.
"""
def __init__(self, a, b):
self.a = a
self.b = b

def __nonzero__(self):
return not self.a.same_as(self.b)

def __bool__(self):
return self.__nonzero__()

def asnode(self):
"""Convert node."""
return _make.NE(self.a, self.b)


class Expr(ExprOp, NodeBase):
"""Base class of all tvm Expressions"""
# In Python3, We have to explicity tell interpreter to retain __hash__ if we overide __eq__
# https://docs.python.org/3.1/reference/datamodel.html#object.__hash__
__hash__ = NodeBase.__hash__


class ConstExpr(Expr):
pass

Expand Down Expand Up @@ -215,19 +274,11 @@ class Max(BinaryOpExpr):

@register_node
class EQ(CmpExpr):
def __nonzero__(self):
return self.a.same_as(self.b)

def __bool__(self):
return self.__nonzero__()
pass

@register_node
class NE(CmpExpr):
def __nonzero__(self):
return not self.a.same_as(self.b)

def __bool__(self):
return self.__nonzero__()
pass

@register_node
class LT(CmpExpr):
Expand Down
2 changes: 2 additions & 0 deletions tests/python/unittest/test_ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_if():
A[0] = A[i] + 2

body = ib.get()
assert A == A
assert isinstance(body, tvm.stmt.For)
body = body.body
assert isinstance(body, tvm.stmt.IfThenElse)
Expand All @@ -42,6 +43,7 @@ def test_prefetch():
A = tvm.placeholder((10, 20), name="A")
ib = tvm.ir_builder.create()
n = tvm.var("n")

with ib.for_range(0, n, name="i") as i:
ib.emit(
tvm.make.Prefetch(
Expand Down