Skip to content

Commit

Permalink
Update AST and Shape() implementation (apache#5)
Browse files Browse the repository at this point in the history
* Update AST.

* ShapeOf.

* ShapeOf.

* Address comment.
  • Loading branch information
ZihengJiang authored and junrushao committed Feb 9, 2023
1 parent 55b757f commit 83e5d23
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 18 deletions.
7 changes: 4 additions & 3 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -370,12 +370,13 @@ class RelayExprNode : public BaseExprNode {
mutable Type checked_type_ = Type(nullptr);

/*!
* \brief Stores the result of static shape analysis.
* \brief Stores the result of static shape analysis. It must be a RelayExpr
* and ObjectRef is used here to avoid cyclic typing.
*
* \note The value will be optional if a static shape can not be inferred.
* use .shape() instead to acesss an always defined shape expression.
*/
Optional<Array<PrimExpr>> shape_ = Optional<Array<PrimExpr>>();
mutable Optional<ObjectRef> shape_ = Optional<ObjectRef>();

/*!
* \return The checked_type
Expand All @@ -387,7 +388,7 @@ class RelayExprNode : public BaseExprNode {
*
* Only valid when the expression's type is a Tensor.
*/
inline RelayExpr shape() const;
RelayExpr shape() const;

/*!
* \brief Check if the inferred(checked) type of the Expr
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
namespace tvm {
namespace relax {

using Expr = RelayExpr;
using ExprNode = RelayExprNode;
using relay::Id;
using relay::Call;
using relay::Tuple;
using relay::TupleGetItem;
using ExprNode = RelayExprNode;
using Expr = RelayExpr;

/*! \brief A shape expression which allows users to construct a shape containing PrimExpr.
*/
Expand Down Expand Up @@ -121,13 +121,13 @@ class VarNode : public ExprNode {
class Var : public Expr {
public:
TVM_DLL Var(String name_hint,
runtime::Optional<Array<PrimExpr>> shape_annotation,
runtime::Optional<Expr> shape_annotation,
runtime::Optional<Type> type_annotation,
Span span = Span())
: Var(Id(name_hint), shape_annotation, type_annotation, span) {}

TVM_DLL Var(Id vid,
runtime::Optional<Array<PrimExpr>> shape_annotation,
runtime::Optional<Expr> shape_annotation,
runtime::Optional<Type> type_annotation,
Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode);
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ def checked_type(self):
raise ValueError("The type checker has not populated" " the checked_type for this node")
return ret

@property
def shape(self):
"""Get the shape of tvm.relay.Expr.
Returns
-------
shape : tvm.ir.RelayExpr
The expression that represents the shape.
"""
return _ffi_api.RelayExprShape(self)



@tvm._ffi.register_object("GlobalVar")
class GlobalVar(RelayExpr):
Expand Down
18 changes: 17 additions & 1 deletion python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,31 @@ class ShapeExpr(Expr):
def __init__(self, values: List[PrimExpr]) -> None:
self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values)

def __getitem__(self, index):
if index >= len(self):
raise IndexError("Tuple index out of range")
return self.values[index]

def __len__(self):
return len(self.values)

def make_shape(shape: List[PrimExpr]) -> ShapeExpr:
if isinstance(shape, (list, tuple)):
return ShapeExpr(shape)
else:
raise ValueError


@tvm._ffi.register_object("relax.expr.Var")
class Var(Expr):
id: Id
type_annotation: Optional[Type]

def __init__(self, name_hint: str,
shape_annotation: Optional[List[Type]] = None,
shape_annotation: Optional[Expr] = None,
type_annotation: Optional[Type] = None) -> None:
if shape_annotation is not None:
shape_annotation = make_shape(shape_annotation)
self.__init_handle_by_constructor__(_ffi_api.Var, name_hint,
shape_annotation,
type_annotation)
Expand Down
22 changes: 18 additions & 4 deletions src/relax/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,25 @@
#include <tvm/relax/expr.h>

namespace tvm {

RelayExpr RelayExprNode::shape() const {
if (this->shape_.defined()) {
return Downcast<RelayExpr>(this->shape_);
}
static const Op& op = Op::Get("relax.shape_of");
RelayExpr self = GetRef<RelayExpr>(this);
return relay::Call(op, {self}, {}, {});
}

TVM_REGISTER_GLOBAL("ir.RelayExprShape")
.set_body_typed([](RelayExpr expr) {
return expr->shape();
});

namespace relax {

using tvm::runtime::Optional;


TVM_REGISTER_NODE_TYPE(ShapeExprNode);

ShapeExpr::ShapeExpr(Array<PrimExpr> values) {
Expand All @@ -41,7 +55,7 @@ TVM_REGISTER_GLOBAL("relax.ShapeExpr")
TVM_REGISTER_NODE_TYPE(VarNode);

Var::Var(Id vid,
Optional<Array<PrimExpr>> shape_annotation,
Optional<Expr> shape_annotation,
Optional<Type> type_annotation,
Span span) {
ObjectPtr<VarNode> n = make_object<VarNode>();
Expand All @@ -54,7 +68,7 @@ Var::Var(Id vid,

TVM_REGISTER_GLOBAL("relax.Var")
.set_body_typed([](String name_hint,
Optional<Array<PrimExpr>> shape_annotation,
Optional<Expr> shape_annotation,
Optional<Type> type_annotation) {
return Var(name_hint, shape_annotation, type_annotation);
});
Expand All @@ -64,7 +78,7 @@ TVM_REGISTER_NODE_TYPE(DataflowVarNode);

TVM_REGISTER_GLOBAL("relax.DataflowVar")
.set_body_typed([](String name_hint,
Optional<Array<PrimExpr>> shape_annotation,
Optional<Expr> shape_annotation,
Optional<Type> type_annotation) {
return DataflowVar(name_hint, shape_annotation, type_annotation);
});
Expand Down
27 changes: 21 additions & 6 deletions src/relax/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,34 @@
namespace tvm {
namespace relax {

// call_dps

RELAY_REGISTER_OP("relax.call_dps")
.set_num_inputs(3)
.add_argument("shape", "ShapeExpr", "The output shape.")
.add_argument("func", "Expr", "The destination-passing-style function.")
.add_argument("args", "Tuple", "The input arguments.");

Expr MakeCallDPS(ShapeExpr shape, Expr func, Tuple args) {
static const Op& op = Op::Get("call_dps");
static const Op& op = Op::Get("relax.call_dps");
return Call(op, {shape, func, args}, {}, {});
}

TVM_REGISTER_GLOBAL("relax.op.call_dps")
.set_body_typed(MakeCallDPS);

RELAY_REGISTER_OP("call_dps")
.set_num_inputs(3)
.add_argument("shape", "ShapeExpr", "The output shape.")
.add_argument("func", "Expr", "The destination-passing-style function.")
.add_argument("args", "Tuple", "The input arguments.");
// shape_of

RELAY_REGISTER_OP("relax.shape_of")
.set_num_inputs(1)
.add_argument("input", "Expr", "The input expression");

Expr MakeShapeOf(Expr expr) {
static const Op& op = Op::Get("relax.shape_of");
return Call(op, {expr}, {}, {});
}

TVM_REGISTER_GLOBAL("relax.op.shape_of")
.set_body_typed(MakeShapeOf);
} // namespace relax
} // namespace tvm
14 changes: 14 additions & 0 deletions tests/python/relax/test_ast.py → tests/python/relax/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,19 @@ def test_func():
assert func.name.name_hint == "func"


def test_shape_of():
v0 = rx.Var("v0")
s0 = v0.shape
assert isinstance(s0, tvm.relay.Call)
assert s0.op.name == "relax.shape_of"

shape_anno = [96, 54]
v1 = rx.Var("v1", shape_anno)
s1 = v1.shape
for x, y in zip(shape_anno, s1):
assert x == y


if __name__ == "__main__":
test_var()
test_dataflow_var()
Expand All @@ -123,3 +136,4 @@ def test_func():
test_seq_expr()
test_shape_expr()
test_func()
test_shape_of()

0 comments on commit 83e5d23

Please sign in to comment.