diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index b145fa8b9532..4ec0d45a4752 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -370,12 +370,13 @@ class RelayExprNode : public BaseExprNode { mutable Type checked_type_ = Type(nullptr); /*! - * \brief Stores the result of static shape analysis. + * \brief Stores the result of static shape analysis. It must be a RelayExpr + * and ObjectRef is used here to avoid cyclic typing. * * \note The value will be optional if a static shape can not be inferred. * use .shape() instead to acesss an always defined shape expression. */ - Optional> shape_ = Optional>(); + mutable Optional shape_ = Optional(); /*! * \return The checked_type @@ -387,7 +388,7 @@ class RelayExprNode : public BaseExprNode { * * Only valid when the expression's type is a Tensor. */ - inline RelayExpr shape() const; + RelayExpr shape() const; /*! * \brief Check if the inferred(checked) type of the Expr diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index bf883e6d2f4a..8cb74a595c67 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -31,12 +31,12 @@ namespace tvm { namespace relax { +using Expr = RelayExpr; +using ExprNode = RelayExprNode; using relay::Id; using relay::Call; using relay::Tuple; using relay::TupleGetItem; -using ExprNode = RelayExprNode; -using Expr = RelayExpr; /*! \brief A shape expression which allows users to construct a shape containing PrimExpr. */ @@ -121,13 +121,13 @@ class VarNode : public ExprNode { class Var : public Expr { public: TVM_DLL Var(String name_hint, - runtime::Optional> shape_annotation, + runtime::Optional shape_annotation, runtime::Optional type_annotation, Span span = Span()) : Var(Id(name_hint), shape_annotation, type_annotation, span) {} TVM_DLL Var(Id vid, - runtime::Optional> shape_annotation, + runtime::Optional shape_annotation, runtime::Optional type_annotation, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode); diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 3c3fefb6d6c6..4bc53e5549a1 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -51,6 +51,18 @@ def checked_type(self): raise ValueError("The type checker has not populated" " the checked_type for this node") return ret + @property + def shape(self): + """Get the shape of tvm.relay.Expr. + + Returns + ------- + shape : tvm.ir.RelayExpr + The expression that represents the shape. + """ + return _ffi_api.RelayExprShape(self) + + @tvm._ffi.register_object("GlobalVar") class GlobalVar(RelayExpr): diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 3681d9582685..172cf6dee46e 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -37,6 +37,20 @@ class ShapeExpr(Expr): def __init__(self, values: List[PrimExpr]) -> None: self.__init_handle_by_constructor__(_ffi_api.ShapeExpr, values) + def __getitem__(self, index): + if index >= len(self): + raise IndexError("Tuple index out of range") + return self.values[index] + + def __len__(self): + return len(self.values) + +def make_shape(shape: List[PrimExpr]) -> ShapeExpr: + if isinstance(shape, (list, tuple)): + return ShapeExpr(shape) + else: + raise ValueError + @tvm._ffi.register_object("relax.expr.Var") class Var(Expr): @@ -44,8 +58,10 @@ class Var(Expr): type_annotation: Optional[Type] def __init__(self, name_hint: str, - shape_annotation: Optional[List[Type]] = None, + shape_annotation: Optional[Expr] = None, type_annotation: Optional[Type] = None) -> None: + if shape_annotation is not None: + shape_annotation = make_shape(shape_annotation) self.__init_handle_by_constructor__(_ffi_api.Var, name_hint, shape_annotation, type_annotation) diff --git a/src/relax/expr.cc b/src/relax/expr.cc index 324df1895c64..7a4fa8a30826 100644 --- a/src/relax/expr.cc +++ b/src/relax/expr.cc @@ -19,11 +19,25 @@ #include namespace tvm { + +RelayExpr RelayExprNode::shape() const { + if (this->shape_.defined()) { + return Downcast(this->shape_); + } + static const Op& op = Op::Get("relax.shape_of"); + RelayExpr self = GetRef(this); + return relay::Call(op, {self}, {}, {}); +} + +TVM_REGISTER_GLOBAL("ir.RelayExprShape") +.set_body_typed([](RelayExpr expr) { + return expr->shape(); +}); + namespace relax { using tvm::runtime::Optional; - TVM_REGISTER_NODE_TYPE(ShapeExprNode); ShapeExpr::ShapeExpr(Array values) { @@ -41,7 +55,7 @@ TVM_REGISTER_GLOBAL("relax.ShapeExpr") TVM_REGISTER_NODE_TYPE(VarNode); Var::Var(Id vid, - Optional> shape_annotation, + Optional shape_annotation, Optional type_annotation, Span span) { ObjectPtr n = make_object(); @@ -54,7 +68,7 @@ Var::Var(Id vid, TVM_REGISTER_GLOBAL("relax.Var") .set_body_typed([](String name_hint, - Optional> shape_annotation, + Optional shape_annotation, Optional type_annotation) { return Var(name_hint, shape_annotation, type_annotation); }); @@ -64,7 +78,7 @@ TVM_REGISTER_NODE_TYPE(DataflowVarNode); TVM_REGISTER_GLOBAL("relax.DataflowVar") .set_body_typed([](String name_hint, - Optional> shape_annotation, + Optional shape_annotation, Optional type_annotation) { return DataflowVar(name_hint, shape_annotation, type_annotation); }); diff --git a/src/relax/op.cc b/src/relax/op.cc index c3c8ed232917..676dfb22a95c 100644 --- a/src/relax/op.cc +++ b/src/relax/op.cc @@ -22,19 +22,34 @@ namespace tvm { namespace relax { +// call_dps + +RELAY_REGISTER_OP("relax.call_dps") +.set_num_inputs(3) +.add_argument("shape", "ShapeExpr", "The output shape.") +.add_argument("func", "Expr", "The destination-passing-style function.") +.add_argument("args", "Tuple", "The input arguments."); + Expr MakeCallDPS(ShapeExpr shape, Expr func, Tuple args) { - static const Op& op = Op::Get("call_dps"); + static const Op& op = Op::Get("relax.call_dps"); return Call(op, {shape, func, args}, {}, {}); } TVM_REGISTER_GLOBAL("relax.op.call_dps") .set_body_typed(MakeCallDPS); -RELAY_REGISTER_OP("call_dps") -.set_num_inputs(3) -.add_argument("shape", "ShapeExpr", "The output shape.") -.add_argument("func", "Expr", "The destination-passing-style function.") -.add_argument("args", "Tuple", "The input arguments."); +// shape_of + +RELAY_REGISTER_OP("relax.shape_of") +.set_num_inputs(1) +.add_argument("input", "Expr", "The input expression"); + +Expr MakeShapeOf(Expr expr) { + static const Op& op = Op::Get("relax.shape_of"); + return Call(op, {expr}, {}, {}); +} +TVM_REGISTER_GLOBAL("relax.op.shape_of") +.set_body_typed(MakeShapeOf); } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_ast.py b/tests/python/relax/test_expr.py similarity index 91% rename from tests/python/relax/test_ast.py rename to tests/python/relax/test_expr.py index 2e55909f7531..60e8bd340f5b 100644 --- a/tests/python/relax/test_ast.py +++ b/tests/python/relax/test_expr.py @@ -113,6 +113,19 @@ def test_func(): assert func.name.name_hint == "func" +def test_shape_of(): + v0 = rx.Var("v0") + s0 = v0.shape + assert isinstance(s0, tvm.relay.Call) + assert s0.op.name == "relax.shape_of" + + shape_anno = [96, 54] + v1 = rx.Var("v1", shape_anno) + s1 = v1.shape + for x, y in zip(shape_anno, s1): + assert x == y + + if __name__ == "__main__": test_var() test_dataflow_var() @@ -123,3 +136,4 @@ def test_func(): test_seq_expr() test_shape_expr() test_func() + test_shape_of()