Skip to content

Commit

Permalink
[PatternLang] Add ConstantPattern (apache#5689)
Browse files Browse the repository at this point in the history
* Add ConstantPattern

* update doc
  • Loading branch information
comaniac authored and trevor-m committed Jun 18, 2020
1 parent 1e849b7 commit 73f1470
Show file tree
Hide file tree
Showing 9 changed files with 337 additions and 115 deletions.
43 changes: 43 additions & 0 deletions docs/langref/relay_pattern.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ for more use cases.

.. _tests/python/relay/test_dataflow_pattern.py: https://github.com/apache/incubator-tvm/blob/master/tests/python/relay/test_dataflow_pattern.py

.. note::

If you cannot find the corresponding pattern node to match the Relay node you want,
you are welcome to raise an issue or submit a PR to add it.

Matching One of Two Ops
***********************

Expand Down Expand Up @@ -131,6 +136,44 @@ The next example is matching a pattern of batch_norm -> get(0) -> relu:
out = relay.nn.relu(tuple_get_item_node)
pat.match(out)
The next example is matching a constant node regarding its values. This is useful to check
if a specific parameter in a subgraph has been bind or not.

.. code-block:: python
def test_match_constant():
conv2d = is_op('nn.conv2d')(wildcard(), ConstantPattern())
pattern = is_op('nn.bias_add')(conv2d, wildcard())
x = relay.var('x', shape=(1, 3, 224, 224))
w = relay.var('w', shape=(3, 3, 3, 3))
b = relay.var('b', shape=(3, ))
conv2d = relay.op.nn.conv2d(x, w)
out = relay.op.nn.bias_add(conv2d, b)
func = relay.Function([x, w, b], out)
mod = tvm.IRModule.from_expr(func)
# Two inputs of the conv2d in the graph are VarNode by default, so no match.
assert not pattern.match(mod['main'].body)
# The second input (weight) has been bind with constant values so it is now a constant node.
mod["main"] = bind_params_by_name(mod["main"],
{'w': tvm.nd.array(np.ones(shape=(3, 3, 3, 3)))})
assert pattern.match(mod['main'].body)
On the other hand, if you need to match the constant with a specific value, you can directly
use ``ExprPattern``. This could be useful for algebraic simplify.

.. code-block:: python
def test_match_plus_zero():
zero = (ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0)))
pattern = wildcard() + zero
x = relay.Var('x')
y = x + relay.const(0)
assert pattern.match(y)
The next example is matching function nodes with a specific attribute:

.. code-block:: python
Expand Down
18 changes: 18 additions & 0 deletions include/tvm/relay/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,24 @@ class VarPattern : public DFPattern {
TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode);
};

/*!
* \brief A Pattern to Match a Relay Constant
*/
class ConstantPattern;
/*! \brief Container for Constant */
class ConstantPatternNode : public DFPatternNode {
public:
void VisitAttrs(tvm::AttrVisitor* v) {}

static constexpr const char* _type_key = "relay.dataflow_pattern.ConstantPattern";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantPatternNode, DFPatternNode);
};

class ConstantPattern : public DFPattern {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ConstantPattern, DFPattern, ConstantPatternNode);
};

/*!
* \brief Call corresponds to operator invocation.
* Corresponds to the operator in computational graph terminology.
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/relay/dataflow_pattern_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPatternDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
Expand All @@ -111,6 +112,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
return vtable;
}
Expand All @@ -134,6 +136,7 @@ class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
void VisitDFPattern_(const TuplePatternNode* op) override;
void VisitDFPattern_(const TypePatternNode* op) override;
void VisitDFPattern_(const VarPatternNode* op) override;
void VisitDFPattern_(const ConstantPatternNode* op) override;
void VisitDFPattern_(const WildcardPatternNode* op) override;

protected:
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,14 @@ def __init__(self, name_hint: str, type_annotation=None):
ffi.VarPattern, name_hint, type_annotation)


@register_df_node
class ConstantPattern(DFPattern):
"""A pattern matching a Relay Constant.
"""
def __init__(self):
self.__init_handle_by_constructor__(ffi.ConstantPattern)


@register_df_node
class CallPattern(DFPattern):
"""A pattern matching a function call node in Relay.
Expand Down
5 changes: 5 additions & 0 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;

void ClearMap(size_t watermark);
Expand Down Expand Up @@ -394,6 +395,10 @@ bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& exp
return matches;
}

bool DFPatternMatcher::VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) {
return expr.as<ConstantNode>() != nullptr;
}

bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) {
return true;
}
Expand Down
12 changes: 12 additions & 0 deletions src/relay/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ")";
});

TVM_REGISTER_NODE_TYPE(ConstantPatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ConstantPattern").set_body_typed([]() {
auto c = ConstantPattern(make_object<ConstantPatternNode>());
return c;
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ConstantPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
p->stream << "ConstantPattern()";
});

CallPattern::CallPattern(DFPattern op, Array<DFPattern> args, Attrs attrs, Array<Type> type_args) {
ObjectPtr<CallPatternNode> n = make_object<CallPatternNode>();
n->op = std::move(op);
Expand Down
2 changes: 2 additions & 0 deletions src/relay/ir/dataflow_pattern_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPatte

void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {}

void DFPatternVisitor::VisitDFPattern_(const ConstantPatternNode* op) {}

void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {}

} // namespace relay
Expand Down
2 changes: 2 additions & 0 deletions src/relay/ir/indexed_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {

void VisitDFPattern_(const VarPatternNode* op, NodePtr parent) override {}

void VisitDFPattern_(const ConstantPatternNode* op, NodePtr parent) override {}

void VisitDFPattern_(const WildcardPatternNode* op, NodePtr parent) override {}
};
return Annotator(Creator().CreateGraph(pattern)).Annotate();
Expand Down
Loading

0 comments on commit 73f1470

Please sign in to comment.