Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PatternLang] Add ConstantPattern #5689

Merged
merged 2 commits into from
May 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions docs/langref/relay_pattern.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ for more use cases.

.. _tests/python/relay/test_dataflow_pattern.py: https://github.com/apache/incubator-tvm/blob/master/tests/python/relay/test_dataflow_pattern.py

.. note::

If you cannot find the corresponding pattern node to match the Relay node you want,
you are welcome to raise an issue or submit a PR to add it.

Matching One of Two Ops
***********************

Expand Down Expand Up @@ -131,6 +136,44 @@ The next example is matching a pattern of batch_norm -> get(0) -> relu:
out = relay.nn.relu(tuple_get_item_node)
pat.match(out)
The next example is matching a constant node regarding its values. This is useful to check
if a specific parameter in a subgraph has been bind or not.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bound, can be fixed in the next PR


.. code-block:: python
def test_match_constant():
conv2d = is_op('nn.conv2d')(wildcard(), ConstantPattern())
pattern = is_op('nn.bias_add')(conv2d, wildcard())
x = relay.var('x', shape=(1, 3, 224, 224))
w = relay.var('w', shape=(3, 3, 3, 3))
b = relay.var('b', shape=(3, ))
conv2d = relay.op.nn.conv2d(x, w)
out = relay.op.nn.bias_add(conv2d, b)
func = relay.Function([x, w, b], out)
mod = tvm.IRModule.from_expr(func)
# Two inputs of the conv2d in the graph are VarNode by default, so no match.
assert not pattern.match(mod['main'].body)
# The second input (weight) has been bind with constant values so it is now a constant node.
mod["main"] = bind_params_by_name(mod["main"],
{'w': tvm.nd.array(np.ones(shape=(3, 3, 3, 3)))})
assert pattern.match(mod['main'].body)
On the other hand, if you need to match the constant with a specific value, you can directly
use ``ExprPattern``. This could be useful for algebraic simplify.

.. code-block:: python
def test_match_plus_zero():
zero = (ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0)))
pattern = wildcard() + zero
x = relay.Var('x')
y = x + relay.const(0)
assert pattern.match(y)
The next example is matching function nodes with a specific attribute:

.. code-block:: python
Expand Down
18 changes: 18 additions & 0 deletions include/tvm/relay/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,24 @@ class VarPattern : public DFPattern {
TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode);
};

/*!
* \brief A Pattern to Match a Relay Constant
*/
class ConstantPattern;
/*! \brief Container for Constant */
class ConstantPatternNode : public DFPatternNode {
public:
void VisitAttrs(tvm::AttrVisitor* v) {}

static constexpr const char* _type_key = "relay.dataflow_pattern.ConstantPattern";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantPatternNode, DFPatternNode);
};

class ConstantPattern : public DFPattern {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ConstantPattern, DFPattern, ConstantPatternNode);
};

/*!
* \brief Call corresponds to operator invocation.
* Corresponds to the operator in computational graph terminology.
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/relay/dataflow_pattern_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPatternDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
Expand All @@ -111,6 +112,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
return vtable;
}
Expand All @@ -134,6 +136,7 @@ class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
void VisitDFPattern_(const TuplePatternNode* op) override;
void VisitDFPattern_(const TypePatternNode* op) override;
void VisitDFPattern_(const VarPatternNode* op) override;
void VisitDFPattern_(const ConstantPatternNode* op) override;
void VisitDFPattern_(const WildcardPatternNode* op) override;

protected:
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,14 @@ def __init__(self, name_hint: str, type_annotation=None):
ffi.VarPattern, name_hint, type_annotation)


@register_df_node
class ConstantPattern(DFPattern):
"""A pattern matching a Relay Constant.
"""
def __init__(self):
self.__init_handle_by_constructor__(ffi.ConstantPattern)


@register_df_node
class CallPattern(DFPattern):
"""A pattern matching a function call node in Relay.
Expand Down
5 changes: 5 additions & 0 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;

void ClearMap(size_t watermark);
Expand Down Expand Up @@ -394,6 +395,10 @@ bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& exp
return matches;
}

bool DFPatternMatcher::VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) {
return expr.as<ConstantNode>() != nullptr;
}

bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) {
return true;
}
Expand Down
12 changes: 12 additions & 0 deletions src/relay/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ")";
});

TVM_REGISTER_NODE_TYPE(ConstantPatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ConstantPattern").set_body_typed([]() {
auto c = ConstantPattern(make_object<ConstantPatternNode>());
return c;
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ConstantPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
p->stream << "ConstantPattern()";
});

CallPattern::CallPattern(DFPattern op, Array<DFPattern> args, Attrs attrs, Array<Type> type_args) {
ObjectPtr<CallPatternNode> n = make_object<CallPatternNode>();
n->op = std::move(op);
Expand Down
2 changes: 2 additions & 0 deletions src/relay/ir/dataflow_pattern_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPatte

void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {}

void DFPatternVisitor::VisitDFPattern_(const ConstantPatternNode* op) {}

void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {}

} // namespace relay
Expand Down
2 changes: 2 additions & 0 deletions src/relay/ir/indexed_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {

void VisitDFPattern_(const VarPatternNode* op, NodePtr parent) override {}

void VisitDFPattern_(const ConstantPatternNode* op, NodePtr parent) override {}

void VisitDFPattern_(const WildcardPatternNode* op, NodePtr parent) override {}
};
return Annotator(Creator().CreateGraph(pattern)).Annotate();
Expand Down
Loading