Skip to content

Commit

Permalink
Add a relay LetPattern
Browse files Browse the repository at this point in the history
  • Loading branch information
mbrookhart committed Jan 23, 2021
1 parent 6787d74 commit f5662b3
Show file tree
Hide file tree
Showing 9 changed files with 200 additions and 5 deletions.
29 changes: 29 additions & 0 deletions docs/langref/relay_pattern.rst
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,24 @@ are matched:
assert pat.match(relay.expr.If(cond, x, y))
A Relay ``Let`` expression can be matched if all of its variable, value, and body
are matched:

.. code-block:: python
def test_match_let():
x = is_var("x")
y = is_var("y")
let_var = is_var("let")
pat = is_let(let_var, is_op("less")(x, y), let_var)
x = relay.var("x")
y = relay.var("y")
lv = relay.var("let")
cond = x < y
assert pat.match(relay.expr.Let(lv, cond, lv))
Matching Diamonds and Post-Dominator Graphs
*******************************************

Expand Down Expand Up @@ -310,6 +328,7 @@ The high level design is to introduce a language of patterns for now we propose
| is_tuple()
| is_tuple_get_item(pattern, index = None)
| is_if(cond, tru, fls)
| is_let(var, value, body)
| pattern1 `|` pattern2
| dominates(parent_pattern, path_pattern, child_pattern)
| FunctionPattern(params, body)
Expand Down Expand Up @@ -367,6 +386,16 @@ Function Pattern

Match a Function with a body and parameters

If Pattern
**********

Match an If with condition, true branch, and false branch

Let Pattern
***********

Match a Let with a variable, value, and body

Applications
============

Expand Down
37 changes: 37 additions & 0 deletions include/tvm/relay/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,43 @@ class FunctionPattern : public DFPattern {
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionPatternNode);
};

/*! \brief A binding of a sub-network. */
class LetPatternNode : public DFPatternNode {
public:
/*! \brief The variable we bind to */
DFPattern var;
/*! \brief The value we bind var to */
DFPattern value;
/*! \brief The body of the let binding */
DFPattern body;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("body", &body);
}

static constexpr const char* _type_key = "relay.dataflow_pattern.LetPattern";
TVM_DECLARE_FINAL_OBJECT_INFO(LetPatternNode, DFPatternNode);
};

/*!
* \brief Let binding that binds a local var
*/
class LetPattern : public DFPattern {
public:
/*!
* \brief The constructor
* \param var The variable that is bound to.
* \param value The value used to bind to the variable.
* \param body The body of the let binding.
* \param span The source span of the expression.
*/
TVM_DLL LetPattern(DFPattern var, DFPattern value, DFPattern body);

TVM_DEFINE_OBJECT_REF_METHODS(LetPattern, DFPattern, LetPatternNode);
};

/*! \brief Tuple of multiple Exprs */
class TuplePattern;
/*! \brief Tuple container */
Expand Down
11 changes: 7 additions & 4 deletions include/tvm/relay/dataflow_pattern_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,19 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const FunctionPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const IfPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const LetPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const IfPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPatternDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
Expand All @@ -115,9 +116,10 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(FunctionPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(IfPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(LetPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(IfPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
Expand All @@ -143,10 +145,11 @@ class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
void VisitDFPattern_(const DominatorPatternNode* op) override;
void VisitDFPattern_(const ExprPatternNode* op) override;
void VisitDFPattern_(const FunctionPatternNode* op) override;
void VisitDFPattern_(const IfPatternNode* op) override;
void VisitDFPattern_(const LetPatternNode* op) override;
void VisitDFPattern_(const ShapePatternNode* op) override;
void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
void VisitDFPattern_(const TuplePatternNode* op) override;
void VisitDFPattern_(const IfPatternNode* op) override;
void VisitDFPattern_(const TypePatternNode* op) override;
void VisitDFPattern_(const VarPatternNode* op) override;
void VisitDFPattern_(const WildcardPatternNode* op) override;
Expand Down
44 changes: 44 additions & 0 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,29 @@ def is_if(cond, true_branch, false_branch):
return IfPattern(cond, true_branch, false_branch)


def is_let(var, value, body):
"""
Syntatic sugar for creating an IfPattern.
Parameters
----------
var: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the variable of Let.
value: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the value of Let.
body: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the body where the binding is in effect.
Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
The resulting pattern.
"""
return LetPattern(var, value, body)


def wildcard() -> "DFPattern":
"""
Syntatic sugar for creating a WildcardPattern.
Expand Down Expand Up @@ -579,6 +602,27 @@ def __init__(self, cond: "DFPattern", true_branch: "DFPattern", false_branch: "D
self.__init_handle_by_constructor__(ffi.IfPattern, cond, true_branch, false_branch)


@register_df_node
class LetPattern(DFPattern):
"""A patern matching a Relay If.
Parameters
----------
var: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the variable of Let.
value: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the value of Let.
body: tvm.relay.dataflow_pattern.DFPattern
The pattern describing the body where the binding is in effect.
"""

def __init__(self, var: "DFPattern", value: "DFPattern", body: "DFPattern"):
self.__init_handle_by_constructor__(ffi.LetPattern, var, value, body)


@register_df_node
class TuplePattern(DFPattern):
"""A patern matching a Relay Tuple.
Expand Down
11 changes: 10 additions & 1 deletion src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,11 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const LetPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
Expand Down Expand Up @@ -423,6 +424,14 @@ bool DFPatternMatcher::VisitDFPattern_(const IfPatternNode* op, const Expr& expr
return false;
}

bool DFPatternMatcher::VisitDFPattern_(const LetPatternNode* op, const Expr& expr) {
if (const auto* let_node = expr.as<LetNode>()) {
return VisitDFPattern(op->var, let_node->var) && VisitDFPattern(op->value, let_node->value) &&
VisitDFPattern(op->body, let_node->body);
}
return false;
}

Expr InferType(const Expr& expr) {
auto mod = IRModule::FromExpr(expr);
mod = transform::InferType()(mod);
Expand Down
22 changes: 22 additions & 0 deletions src/relay/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "FunctionPatternNode(" << node->params << ", " << node->body << ")";
});

LetPattern::LetPattern(DFPattern var, DFPattern value, DFPattern body) {
ObjectPtr<LetPatternNode> n = make_object<LetPatternNode>();
n->var = std::move(var);
n->value = std::move(value);
n->body = std::move(body);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(LetPatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.LetPattern")
.set_body_typed([](DFPattern var, DFPattern value, DFPattern body) {
return LetPattern(var, value, body);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LetPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const LetPatternNode*>(ref.get());
p->stream << "LetPatternNode(" << node->var << ", " << node->value << ", " << node->body
<< ")";
});

IfPattern::IfPattern(DFPattern cond, DFPattern true_branch, DFPattern false_branch) {
ObjectPtr<IfPatternNode> n = make_object<IfPatternNode>();
n->cond = std::move(cond);
Expand Down
6 changes: 6 additions & 0 deletions src/relay/ir/dataflow_pattern_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ void DFPatternVisitor::VisitDFPattern_(const IfPatternNode* op) {
VisitDFPattern(op->false_branch);
}

void DFPatternVisitor::VisitDFPattern_(const LetPatternNode* op) {
VisitDFPattern(op->var);
VisitDFPattern(op->value);
VisitDFPattern(op->body);
}

void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPattern(op->pattern); }

void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {}
Expand Down
6 changes: 6 additions & 0 deletions src/relay/ir/indexed_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,12 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {
VisitDFPattern(op->false_branch, graph_.node_map_[GetRef<DFPattern>(op)]);
}

void VisitDFPattern_(const LetPatternNode* op, NodePtr parent) override {
VisitDFPattern(op->var, graph_.node_map_[GetRef<DFPattern>(op)]);
VisitDFPattern(op->value, graph_.node_map_[GetRef<DFPattern>(op)]);
VisitDFPattern(op->body, graph_.node_map_[GetRef<DFPattern>(op)]);
}

void VisitDFPattern_(const TypePatternNode* op, NodePtr parent) override {
VisitDFPattern(op->pattern, graph_.node_map_[GetRef<DFPattern>(op)]);
}
Expand Down
39 changes: 39 additions & 0 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,18 @@ def test_IfPattern():
assert isinstance(pat.false_branch, VarPattern)


def test_LetPattern():
x = is_var("x")
y = is_var("y")
let_var = is_var("let")
pat = is_let(let_var, is_op("less")(x, y), let_var)

assert isinstance(pat, LetPattern)
assert isinstance(pat.var, VarPattern)
assert isinstance(pat.value, CallPattern)
assert isinstance(pat.body, VarPattern)


## MATCHER TESTS


Expand Down Expand Up @@ -233,6 +245,33 @@ def test_no_match_if():
assert not pat.match(relay.expr.If(x < y, y, x))


def test_match_let():
x = is_var("x")
y = is_var("y")
let_var = is_var("let")
pat = is_let(let_var, is_op("less")(x, y), let_var)

x = relay.var("x")
y = relay.var("y")
lv = relay.var("let")
cond = x < y
assert pat.match(relay.expr.Let(lv, cond, lv))


def test_no_match_let():
x = is_var("x")
y = is_var("y")
let_var = is_var("let")
pat = is_let(let_var, is_op("less")(x, y), let_var)

x = relay.var("x")
y = relay.var("y")
lv = relay.var("let")

assert not pat.match(relay.expr.Let(lv, x > y, lv))
assert not pat.match(relay.expr.Let(lv, x < y, lv * x))


def test_match_option():
x = relay.var("x")
w = relay.var("w")
Expand Down

0 comments on commit f5662b3

Please sign in to comment.