diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst index 992954c9a5b1..d77a51980f23 100644 --- a/docs/langref/relay_pattern.rst +++ b/docs/langref/relay_pattern.rst @@ -246,6 +246,24 @@ are matched: assert pat.match(relay.expr.If(cond, x, y)) + +A Relay ``Let`` expression can be matched if all of its variable, value, and body +are matched: + +.. code-block:: python + + def test_match_let(): + x = is_var("x") + y = is_var("y") + let_var = is_var("let") + pat = is_let(let_var, is_op("less")(x, y), let_var) + + x = relay.var("x") + y = relay.var("y") + lv = relay.var("let") + cond = x < y + assert pat.match(relay.expr.Let(lv, cond, lv)) + Matching Diamonds and Post-Dominator Graphs ******************************************* @@ -310,6 +328,7 @@ The high level design is to introduce a language of patterns for now we propose | is_tuple() | is_tuple_get_item(pattern, index = None) | is_if(cond, tru, fls) + | is_let(var, value, body) | pattern1 `|` pattern2 | dominates(parent_pattern, path_pattern, child_pattern) | FunctionPattern(params, body) @@ -367,6 +386,16 @@ Function Pattern Match a Function with a body and parameters +If Pattern +********** + +Match an If with condition, true branch, and false branch + +Let Pattern +*********** + +Match a Let with a variable, value, and body + Applications ============ diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h index 1b0c0aca7ff6..1e6cecfd041b 100644 --- a/include/tvm/relay/dataflow_pattern.h +++ b/include/tvm/relay/dataflow_pattern.h @@ -222,6 +222,42 @@ class FunctionPattern : public DFPattern { TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionPatternNode); }; +/*! \brief A binding of a sub-network. */ +class LetPatternNode : public DFPatternNode { + public: + /*! \brief The variable we bind to */ + DFPattern var; + /*! \brief The value we bind var to */ + DFPattern value; + /*! \brief The body of the let binding */ + DFPattern body; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("var", &var); + v->Visit("value", &value); + v->Visit("body", &body); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.LetPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(LetPatternNode, DFPatternNode); +}; + +/*! + * \brief Let binding that binds a local var + */ +class LetPattern : public DFPattern { + public: + /*! + * \brief The constructor + * \param var The variable that is bound to. + * \param value The value used to bind to the variable. + * \param body The body of the let binding. + */ + TVM_DLL LetPattern(DFPattern var, DFPattern value, DFPattern body); + + TVM_DEFINE_OBJECT_REF_METHODS(LetPattern, DFPattern, LetPatternNode); +}; + /*! \brief Tuple of multiple Exprs */ class TuplePattern; /*! \brief Tuple container */ diff --git a/include/tvm/relay/dataflow_pattern_functor.h b/include/tvm/relay/dataflow_pattern_functor.h index bff9e23ef046..490cdc5e3f9d 100644 --- a/include/tvm/relay/dataflow_pattern_functor.h +++ b/include/tvm/relay/dataflow_pattern_functor.h @@ -84,18 +84,19 @@ class DFPatternFunctor { virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const FunctionPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const IfPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const LetPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const TupleGetItemPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; - virtual R VisitDFPattern_(const IfPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; - virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPatternDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); @@ -115,9 +116,10 @@ class DFPatternFunctor { RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(FunctionPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(IfPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(LetPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode); - RELAY_DFPATTERN_FUNCTOR_DISPATCH(IfPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode); @@ -143,10 +145,11 @@ class DFPatternVisitor : public DFPatternFunctor { void VisitDFPattern_(const DominatorPatternNode* op) override; void VisitDFPattern_(const ExprPatternNode* op) override; void VisitDFPattern_(const FunctionPatternNode* op) override; + void VisitDFPattern_(const IfPatternNode* op) override; + void VisitDFPattern_(const LetPatternNode* op) override; void VisitDFPattern_(const ShapePatternNode* op) override; void VisitDFPattern_(const TupleGetItemPatternNode* op) override; void VisitDFPattern_(const TuplePatternNode* op) override; - void VisitDFPattern_(const IfPatternNode* op) override; void VisitDFPattern_(const TypePatternNode* op) override; void VisitDFPattern_(const VarPatternNode* op) override; void VisitDFPattern_(const WildcardPatternNode* op) override; diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 6f764e1651da..d4a8481d106e 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -337,6 +337,29 @@ def is_if(cond, true_branch, false_branch): return IfPattern(cond, true_branch, false_branch) +def is_let(var, value, body): + """ + Syntatic sugar for creating a LetPattern. + + Parameters + ---------- + var: tvm.relay.dataflow_pattern.DFPattern + The pattern describing the variable of Let. + + value: tvm.relay.dataflow_pattern.DFPattern + The pattern describing the value of Let. + + body: tvm.relay.dataflow_pattern.DFPattern + The pattern describing the body where the binding is in effect. + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting pattern. + """ + return LetPattern(var, value, body) + + def wildcard() -> "DFPattern": """ Syntatic sugar for creating a WildcardPattern. @@ -579,6 +602,27 @@ def __init__(self, cond: "DFPattern", true_branch: "DFPattern", false_branch: "D self.__init_handle_by_constructor__(ffi.IfPattern, cond, true_branch, false_branch) +@register_df_node +class LetPattern(DFPattern): + """A patern matching a Relay Let. + + Parameters + ---------- + var: tvm.relay.dataflow_pattern.DFPattern + The pattern describing the variable of Let. + + value: tvm.relay.dataflow_pattern.DFPattern + The pattern describing the value of Let. + + body: tvm.relay.dataflow_pattern.DFPattern + The pattern describing the body where the binding is in effect. + + """ + + def __init__(self, var: "DFPattern", value: "DFPattern", body: "DFPattern"): + self.__init_handle_by_constructor__(ffi.LetPattern, var, value, body) + + @register_df_node class TuplePattern(DFPattern): """A patern matching a Relay Tuple. diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 459694b8f679..0d9481312137 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -55,10 +55,11 @@ class DFPatternMatcher : public DFPatternFunctor()) { + return VisitDFPattern(op->var, let_node->var) && VisitDFPattern(op->value, let_node->value) && + VisitDFPattern(op->body, let_node->body); + } + return false; +} + Expr InferType(const Expr& expr) { auto mod = IRModule::FromExpr(expr); mod = transform::InferType()(mod); diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc index 1e268fb00d97..4c3b82cc19d4 100644 --- a/src/relay/ir/dataflow_pattern.cc +++ b/src/relay/ir/dataflow_pattern.cc @@ -112,6 +112,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "FunctionPatternNode(" << node->params << ", " << node->body << ")"; }); +LetPattern::LetPattern(DFPattern var, DFPattern value, DFPattern body) { + ObjectPtr n = make_object(); + n->var = std::move(var); + n->value = std::move(value); + n->body = std::move(body); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(LetPatternNode); + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.LetPattern") + .set_body_typed([](DFPattern var, DFPattern value, DFPattern body) { + return LetPattern(var, value, body); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "LetPatternNode(" << node->var << ", " << node->value << ", " << node->body + << ")"; + }); + IfPattern::IfPattern(DFPattern cond, DFPattern true_branch, DFPattern false_branch) { ObjectPtr n = make_object(); n->cond = std::move(cond); diff --git a/src/relay/ir/dataflow_pattern_functor.cc b/src/relay/ir/dataflow_pattern_functor.cc index 25b247306229..828e867b332c 100644 --- a/src/relay/ir/dataflow_pattern_functor.cc +++ b/src/relay/ir/dataflow_pattern_functor.cc @@ -87,6 +87,12 @@ void DFPatternVisitor::VisitDFPattern_(const IfPatternNode* op) { VisitDFPattern(op->false_branch); } +void DFPatternVisitor::VisitDFPattern_(const LetPatternNode* op) { + VisitDFPattern(op->var); + VisitDFPattern(op->value); + VisitDFPattern(op->body); +} + void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPattern(op->pattern); } void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {} diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc index 9ee5c9cf6b85..0f81c2360d0f 100644 --- a/src/relay/ir/indexed_graph.cc +++ b/src/relay/ir/indexed_graph.cc @@ -288,6 +288,12 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { VisitDFPattern(op->false_branch, graph_.node_map_[GetRef(op)]); } + void VisitDFPattern_(const LetPatternNode* op, NodePtr parent) override { + VisitDFPattern(op->var, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->value, graph_.node_map_[GetRef(op)]); + VisitDFPattern(op->body, graph_.node_map_[GetRef(op)]); + } + void VisitDFPattern_(const TypePatternNode* op, NodePtr parent) override { VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); } diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 934ebf462b95..e7b367b8f631 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -138,6 +138,18 @@ def test_IfPattern(): assert isinstance(pat.false_branch, VarPattern) +def test_LetPattern(): + x = is_var("x") + y = is_var("y") + let_var = is_var("let") + pat = is_let(let_var, is_op("less")(x, y), let_var) + + assert isinstance(pat, LetPattern) + assert isinstance(pat.var, VarPattern) + assert isinstance(pat.value, CallPattern) + assert isinstance(pat.body, VarPattern) + + ## MATCHER TESTS @@ -233,6 +245,33 @@ def test_no_match_if(): assert not pat.match(relay.expr.If(x < y, y, x)) +def test_match_let(): + x = is_var("x") + y = is_var("y") + let_var = is_var("let") + pat = is_let(let_var, is_op("less")(x, y), let_var) + + x = relay.var("x") + y = relay.var("y") + lv = relay.var("let") + cond = x < y + assert pat.match(relay.expr.Let(lv, cond, lv)) + + +def test_no_match_let(): + x = is_var("x") + y = is_var("y") + let_var = is_var("let") + pat = is_let(let_var, is_op("less")(x, y), let_var) + + x = relay.var("x") + y = relay.var("y") + lv = relay.var("let") + + assert not pat.match(relay.expr.Let(lv, x > y, lv)) + assert not pat.match(relay.expr.Let(lv, x < y, lv * x)) + + def test_match_option(): x = relay.var("x") w = relay.var("w")