From 8fec0979354975c8ae58fd013d80647c3bd96e10 Mon Sep 17 00:00:00 2001
From: Matthew Brookhart
Date: Fri, 15 May 2020 12:42:51 -0700
Subject: [PATCH] Pattern Language, Matcher, Rewriter, and Function Partitioner
 (#5231)
---
 docs/langref/index.rst | 1 +
 docs/langref/relay_pattern.rst | 141 +++
 include/tvm/relay/dataflow_matcher.h | 98 ++
 include/tvm/relay/dataflow_pattern.h | 359 ++++++
 include/tvm/relay/dataflow_pattern_functor.h | 146 +++
 python/tvm/relay/dataflow_pattern/__init__.py | 581 ++++++++++++
 python/tvm/relay/dataflow_pattern/_ffi.py | 20 +
 src/relay/ir/dataflow_matcher.cc | 650 +++++++++++++
 src/relay/ir/dataflow_pattern.cc | 219 +++++
 src/relay/ir/dataflow_pattern_functor.cc | 77 ++
 src/relay/ir/expr_functor.cc | 7 +-
 src/relay/ir/indexed_graph.cc | 277 ++++++
 src/relay/ir/indexed_graph.h | 138 +++
 tests/python/relay/test_dataflow_pattern.py | 853 ++++++++++++++++++
 14 files changed, 3563 insertions(+), 4 deletions(-)
 create mode 100644 docs/langref/relay_pattern.rst
 create mode 100644 include/tvm/relay/dataflow_matcher.h
 create mode 100644 include/tvm/relay/dataflow_pattern.h
 create mode 100644 include/tvm/relay/dataflow_pattern_functor.h
 create mode 100644 python/tvm/relay/dataflow_pattern/__init__.py
 create mode 100644 python/tvm/relay/dataflow_pattern/_ffi.py
 create mode 100644 src/relay/ir/dataflow_matcher.cc
 create mode 100644 src/relay/ir/dataflow_pattern.cc
 create mode 100644 src/relay/ir/dataflow_pattern_functor.cc
 create mode 100644 src/relay/ir/indexed_graph.cc
 create mode 100644 src/relay/ir/indexed_graph.h
 create mode 100644 tests/python/relay/test_dataflow_pattern.py

diff --git a/docs/langref/index.rst b/docs/langref/index.rst
index 0d296118da26..dcea9fa50c3d 100644
--- a/docs/langref/index.rst
+++ b/docs/langref/index.rst
@@ -46,6 +46,7 @@ algebraic data types, and operators in Relay, respectively.
    relay_type
    relay_adt
    relay_op
+   relay_pattern

 Hybrid Script
 -------------
diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst
new file mode 100644
index 000000000000..7f81b9b48299
--- /dev/null
+++ b/docs/langref/relay_pattern.rst
@@ -0,0 +1,141 @@
.. Licensed to the Apache Software Foundation (ASF) under one
   or more contributor license agreements.  See the NOTICE file
   distributed with this work for additional information
   regarding copyright ownership.  The ASF licenses this file
   to you under the Apache License, Version 2.0 (the
   "License"); you may not use this file except in compliance
   with the License.  You may obtain a copy of the License at

..   http://www.apache.org/licenses/LICENSE-2.0

.. Unless required by applicable law or agreed to in writing,
   software distributed under the License is distributed on an
   "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
   KIND, either express or implied.  See the License for the
   specific language governing permissions and limitations
   under the License.


=========================
Pattern Matching in Relay
=========================

There are many places in TVM where we identify pure data-flow sub-graphs of the Relay program and attempt to transform them in some way. Example passes include fusion, quantization, external code generation, and device-specific optimizations such as the bitpacking and layer slicing used by VTA.

Many of these passes today require a lot of boring boilerplate code to implement, and they require users to think in terms of visitors and AST matching, even though many of the transformations can easily be described in terms of graph rewrites.
To build a rewriter or other advanced machinery, we first need a language of patterns that describes what we can match.

Such a language is useful not just for building a rewriter but also for providing extension points to existing passes. For example, the fusion pass could be parameterized by a set of fusion patterns that describes the capability of your hardware, and the quantization pass could take a set of patterns that describes which operators can be quantized on a given platform.

In the backend world, we could use the same machinery to build a higher-level API on top of bring-your-own-codegen. This API would take a set of patterns describing your hardware's capabilities and an external compiler, providing a relatively smooth heterogeneous experience out of the box.

Examples
========

There are quite a few properties of operators that are worth matching. Below we examine how to match tree properties, and expand on some use cases that are not fully explored in the prototype. The first example is a simple case where we want to match one operator OR another operator::

    def test_match_op_or():
        is_add_or_sub = is_op('add') | is_op('subtract')
        assert is_add_or_sub.match(relay.op.op.get("add"))
        assert is_add_or_sub.match(relay.op.op.get("subtract"))

The next example tries to match a dense operation whose operator carries the element-wise attribute. Since ``nn.dense`` is not marked element-wise, the pattern does not match::

    def test_no_match_attr():
        op = is_op('nn.dense').has_attr("TOpPattern", K_ELEMWISE)
        op_pat = op(wildcard(), wildcard())
        x = relay.var('x')
        y = relay.var('y')
        assert not op_pat.match(relay.op.nn.dense(x, y))

The next example matches a diamond with two inputs at the top of the diamond::

    def test_match_diamond():
        # Pattern
        is_conv2d = is_op('nn.conv2d')(is_input(), is_input())
        path1 = is_op('nn.relu')(is_conv2d)
        path2 = is_op('nn.leaky_relu')(is_conv2d)
        diamond = is_op('add')(path1, path2)

        # Expr
        inp = relay.var('input')
        weight = relay.var('weight')
        conv2d = relay.op.nn.conv2d(inp, weight)
        relu = relay.op.nn.relu(conv2d)
        leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
        out = relu + leaky_relu

        # Check
        assert diamond.match(out)

The final example matches diamonds with a post-dominator relationship. We embed dominator analysis as a type of matching in the pattern language in order to allow pattern matching with unknown topology. This is important because we want to be able to use the language to describe fusion patterns, like a conv2d followed by elementwise operations (``is_elemwise`` is defined here as an attribute pattern, as in the attribute example above)::

    def test_match_dom_diamond():
        # Pattern
        is_conv2d = is_op('nn.conv2d')(is_input(), is_input())
        is_elemwise = wildcard().has_attr("TOpPattern", K_ELEMWISE)
        reduction = is_op('add')(wildcard(), wildcard())
        diamond = dominates(is_conv2d, is_elemwise, reduction)

        # Expr
        inp = relay.var('input')
        weight = relay.var('weight')
        conv2d = relay.op.nn.conv2d(inp, weight)
        relu = relay.op.nn.relu(conv2d)
        leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
        out = relu + leaky_relu

        # Check
        assert diamond.match(out)
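
Patterns also compose with the ``optional`` helper defined in the Python API. The following sketch (illustrative, not part of the patch's test suite) matches a conv2d that may or may not be followed by a bias-add before a relu::

    def test_match_optional_bias():
        # conv2d, optionally followed by bias_add, then relu
        conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
        conv_bias = conv2d.optional(lambda x: is_op('nn.bias_add')(x, wildcard()))
        pattern = is_op('nn.relu')(conv_bias)

        x = relay.var('x')
        w = relay.var('w')
        b = relay.var('b')
        conv = relay.op.nn.conv2d(x, w)
        assert pattern.match(relay.op.nn.relu(conv))
        assert pattern.match(relay.op.nn.relu(relay.op.nn.bias_add(conv, b)))
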
Design
======

The pattern language proposed is designed to be a mirror of Relay's IR with additional support for common scenarios. The goal of the pattern language is to provide a regular-expression-like capability for matching data-flow graphs and doing rewriting.

The high-level design is to introduce a language of patterns; for now we propose the language as::

    Pattern ::= expr
            | *
            | pattern(pattern1, ..., patternN)
            | has_type(pattern, type)
            | has_attr(pattern, attr, attr_value)
            | is_input(name)
            | pattern1 `|` pattern2
            | dominates(parent_pattern, path_pattern, child_pattern)

The above language then provides a matching interface that can both select sub-graphs and verify that a graph matches the pattern.

Expression Pattern
******************

Match a literal expression.

Wildcard
********

Match any expression.

Type Pattern
************

Check that the expression matched by the nested pattern has a particular type.

Attribute Pattern
*****************

Check that the operator matched by the pattern has an attribute with a particular value.

Input
*****

Check that the expression is an input, i.e., has no parents and is a variable.


Alternate
*********

Either match the first pattern or the second pattern.

Domination
**********

Match the child pattern, find a match for the parent pattern, ensuring that the child ultimately dominates the parent (i.e., no nodes outside the pattern use outputs of the parent), and that every node between the child and the parent matches the path pattern.
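
As a combined sketch of the pieces above (illustrative only; the shape and dtype are assumptions), a type-constrained alternate pattern looks like::

    def test_sketch_alt_with_type():
        # Match an add OR a subtract that produces a 10-element float32 tensor
        pat = (wildcard() + wildcard()) | (wildcard() - wildcard())
        typed_pat = pat.has_type(relay.TensorType((10,), "float32"))

        x = relay.var('x', shape=(10,), dtype='float32')
        y = relay.var('y', shape=(10,), dtype='float32')
        assert typed_pat.match(x + y)
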
diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h
new file mode 100644
index 000000000000..58aa6400b650
--- /dev/null
+++ b/include/tvm/relay/dataflow_matcher.h
@@ -0,0 +1,98 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relay/dataflow_matcher.h
 * \brief A pattern matcher for matching dataflow properties.
 */
#ifndef TVM_RELAY_DATAFLOW_MATCHER_H_
#define TVM_RELAY_DATAFLOW_MATCHER_H_

#include <tvm/relay/dataflow_pattern.h>
#include <tvm/relay/expr.h>

#include <string>
#include <utility>

namespace tvm {
namespace relay {

class DFPatternCallback;
/*!
 * \brief Base type of all dataflow pattern callbacks.
 * \sa DFPatternCallback
 */
class DFPatternCallbackNode : public Object {
 public:
  /*! \brief Pattern this callback matches */
  DFPattern pattern_;
  /*! \brief Function to call when finding a matched expression */
  PackedFunc function_;

  void VisitAttrs(tvm::AttrVisitor* v) {}

  static constexpr const char* _type_key = "DFPatternCallbackNode";
  TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object);
};

/*!
 * \brief Managed reference to dataflow pattern callbacks.
 * \sa DFPatternCallbackNode
 */
class DFPatternCallback : public ObjectRef {
 public:
  TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback);
  TVM_DEFINE_OBJECT_REF_METHODS(DFPatternCallback, ObjectRef, DFPatternCallbackNode);
};

/*!
 * \brief Determine if a pattern matches an expression
 *
 * \param pattern The pattern to match
 * \param expr The expression to match
 *
 * \return true if the pattern and the expression match, false otherwise.
 */
bool MatchPattern(DFPattern pattern, Expr expr);

/*!
 * \brief Rewrite an expression based on some number of DFPatternCallbacks
 *
 * \param callbacks An array of DFPatternCallback Nodes
 * \param expr The expression to rewrite
 *
 * \return An Expr with every match of a callback's pattern rewritten by that
 * callback's function
 */
Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr);

/*!
 * \brief Partition all matches of a DFPattern inside an Expr into separate Function calls
 *
 * \param pattern The pattern to match
 * \param expr The expression to partition
 *
 * \return The partitioned Expr.
 */
Expr PartitionPattern(DFPattern pattern, Expr expr);

}  // namespace relay
}  // namespace tvm

#endif  // TVM_RELAY_DATAFLOW_MATCHER_H_
diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h
new file mode 100644
index 000000000000..a8db51f74574
--- /dev/null
+++ b/include/tvm/relay/dataflow_pattern.h
@@ -0,0 +1,359 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm/relay/dataflow_pattern.h
 * \brief A pattern language for matching dataflow properties.
 */
#ifndef TVM_RELAY_DATAFLOW_PATTERN_H_
#define TVM_RELAY_DATAFLOW_PATTERN_H_

#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>

namespace tvm {
namespace relay {

/*!
 * \brief Base type of all dataflow patterns.
 * \sa DFPattern
 */
class DFPatternNode : public Object {
 public:
  static constexpr const char* _type_key = "DFPatternNode";
  TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object);
};

/*!
 * \brief Managed reference to dataflow patterns.
 * \sa DFPatternNode
 */
class DFPattern : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(DFPattern, ObjectRef, DFPatternNode);
};

/*!
 * \brief Pattern for Relay Expression.
 */
class ExprPatternNode : public DFPatternNode {
 public:
  /*! \brief The expression to match. */
  Expr expr;

  void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("expr", &expr); }

  static constexpr const char* _type_key = "relay.dataflow_pattern.ExprPattern";
  TVM_DECLARE_FINAL_OBJECT_INFO(ExprPatternNode, DFPatternNode);
};

/*!
 * \brief A pattern which matches a literal expression.
 *
 * \note Uses structural equality on expressions to check equality.
 *
 */
class ExprPattern : public DFPattern {
 public:
  TVM_DLL explicit ExprPattern(Expr expr);
  TVM_DEFINE_OBJECT_REF_METHODS(ExprPattern, DFPattern, ExprPatternNode);
};

/*!
 * \brief A Pattern to Match a Relay Variable
 */
class VarPattern;
/*! \brief Container for Var */
class VarPatternNode : public DFPatternNode {
 public:
  /*!
   * \brief The name of the Var (optional).
   */
  String name;
  /*!
   * \brief type annotation of the variable.
+ * This field records user provided type annotation of the Var. + * This field is optional and can be None. + */ + Type type_annotation; + + /*! \return The name hint of the variable */ + const String& name_hint() const { return name; } + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("type_annotation", &type_annotation); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.VarPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(VarPatternNode, DFPatternNode); +}; + +class VarPattern : public DFPattern { + public: + TVM_DLL VarPattern(String name_hint, Type type_annotation); + TVM_DEFINE_OBJECT_REF_METHODS(VarPattern, DFPattern, VarPatternNode); +}; + +/*! + * \brief Call corresponds to operator invocation. + * Corresponds to the operator in computational graph terminology. + */ +class CallPattern; +/*! \brief CallPattern container. */ +class CallPatternNode : public DFPatternNode { + public: + /*! + * \brief The operator(function) being invoked + * + * - It can be relay::Op which corresponds to the primitive operators. + * - It can also be user defined functions (Function, GlobalVar, Var). + */ + DFPattern op; + + /*! \brief The arguments(inputs) of the call */ + tvm::Array args; + + /*! \brief The additional attributes */ + Attrs attrs; + + /*! + * \brief The type arguments passed to polymorphic(template) function. + * + * This is the advance feature that is only used when the function is + * polymorphic. It is safe to be ignored in most cases. For example, in the + * following code, the type_args of addone call is [int]. + * + * \code + * + * template + * T addone(T a) { return a + 1; } + * + * void main() { + * int x = addone(10); + * } + * + * \endcode + */ + tvm::Array type_args; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("op", &op); + v->Visit("args", &args); + v->Visit("attrs", &attrs); + v->Visit("type_args", &type_args); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.CallPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(CallPatternNode, DFPatternNode); +}; + +class CallPattern : public DFPattern { + public: + TVM_DLL CallPattern(DFPattern op, Array args, Attrs attrs, Array type_args); + TVM_DEFINE_OBJECT_REF_METHODS(CallPattern, DFPattern, CallPatternNode); +}; + +/*! \brief Tuple of multiple Exprs */ +class TuplePattern; +/*! \brief Tuple container */ +class TuplePatternNode : public DFPatternNode { + public: + /*! \brief the fields of the tuple */ + tvm::Array fields; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); } + + static constexpr const char* _type_key = "relay.dataflow_pattern.TuplePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(TuplePatternNode, DFPatternNode); +}; + +class TuplePattern : public DFPattern { + public: + TVM_DLL explicit TuplePattern(tvm::Array fields); + TVM_DEFINE_OBJECT_REF_METHODS(TuplePattern, DFPattern, TuplePatternNode); +}; + +/*! \brief Get index-th field out of a tuple. */ +class TupleGetItemPattern; +class TupleGetItemPatternNode : public DFPatternNode { + public: + /*! \brief The tuple Expression */ + DFPattern tuple; + /*! 
\brief which value to get */ + int index; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("tuple", &tuple); + v->Visit("index", &index); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.TupleGetItemPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemPatternNode, DFPatternNode); +}; + +class TupleGetItemPattern : public DFPattern { + public: + TVM_DLL TupleGetItemPattern(DFPattern tuple, int index); + TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItemPattern, DFPattern, TupleGetItemPatternNode); +}; + +class AltPattern; +/*! + * \brief Pattern for Alternate Expressions. + */ +class AltPatternNode : public DFPatternNode { + public: + /*! \brief The left optional pattern. */ + DFPattern left; + /*! \brief The right optional pattern. */ + DFPattern right; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("left", &left); + v->Visit("right", &right); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.AltPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(AltPatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches either of two patterns + */ +class AltPattern : public DFPattern { + public: + TVM_DLL AltPattern(DFPattern left, DFPattern right); + TVM_DEFINE_OBJECT_REF_METHODS(AltPattern, DFPattern, AltPatternNode); +}; + +/*! + * \brief Wildcard Pattern. + */ +class WildcardPatternNode : public DFPatternNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "relay.dataflow_pattern.WildcardPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches anything. + */ +class WildcardPattern : public DFPattern { + public: + TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode); +}; + +class TypePattern; +/*! + * \brief Pattern for Types. + */ +class TypePatternNode : public DFPatternNode { + public: + /*! \brief The pattern. */ + DFPattern pattern; + /*! \brief The type to match */ + Type type; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("type", &type); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.TypePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(TypePatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches a type in another pattern + */ +class TypePattern : public DFPattern { + public: + TVM_DLL TypePattern(DFPattern pattern, Type type); + TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode); +}; + +class AttrPattern; +/*! + * \brief Pattern for Attributes. + */ +class AttrPatternNode : public DFPatternNode { + public: + /*! \brief The pattern. */ + DFPattern pattern; + /*! \brief The attribute to match */ + Attrs attrs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("attrs", &attrs); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.AttrPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttrPatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches attributes in another pattern + */ +class AttrPattern : public DFPattern { + public: + TVM_DLL AttrPattern(DFPattern pattern, Attrs attrs); + TVM_DEFINE_OBJECT_REF_METHODS(AttrPattern, DFPattern, AttrPatternNode); +}; + +class DominatorPattern; +/*! + * \brief Dominated Graph Pattern + * Pattern for fuzzy subgraphs where all outputs of the parent are used finally by the child, and + * every operation between the parent and the child matches the path. 
+ */ +class DominatorPatternNode : public DFPatternNode { + public: + /*! \brief The parent. */ + DFPattern parent; + /*! \brief The path. */ + DFPattern path; + /*! \brief The child. */ + DFPattern child; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("parent", &parent); + v->Visit("path", &path); + v->Visit("child", &child); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.DominatorPattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(DominatorPatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches a variable length dominator path + */ +class DominatorPattern : public DFPattern { + public: + TVM_DLL DominatorPattern(DFPattern parent, DFPattern path, DFPattern child); + TVM_DEFINE_OBJECT_REF_METHODS(DominatorPattern, DFPattern, DominatorPatternNode); +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_DATAFLOW_PATTERN_H_ diff --git a/include/tvm/relay/dataflow_pattern_functor.h b/include/tvm/relay/dataflow_pattern_functor.h new file mode 100644 index 000000000000..05c2147c2c49 --- /dev/null +++ b/include/tvm/relay/dataflow_pattern_functor.h @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/dataflow_pattern_functor.h + * \brief A set of passes for operating on pattern graphs. + */ +#ifndef TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_ +#define TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_ + +#include + +#include +#include + +namespace tvm { +namespace relay { + +/*! + * \brief A dynamical functor that dispatches on in the first DFPattern argument. + * + * \tparam FType function signature + * This type is only defined for FType with function signature R(const DFPattern&, + * Args...) + */ +template +class DFPatternFunctor; + +// functions to be overriden. +#define DFPATTERN_FUNCTOR_DEFAULT \ + { return VisitDFPatternDefault_(op, std::forward(args)...); } + +#define RELAY_DFPATTERN_FUNCTOR_DISPATCH(OP) \ + vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) { \ + return self->VisitDFPattern_(static_cast(n.get()), std::forward(args)...); \ + }); + +template +class DFPatternFunctor { + private: + using TSelf = DFPatternFunctor; + using FType = tvm::NodeFunctor; + + public: + /*! \brief virtual destructor */ + virtual ~DFPatternFunctor() {} + /*! + * \brief Same as call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + R operator()(const DFPattern& n, Args... args) { + return VisitDFPattern(n, std::forward(args)...); + } + /*! + * \brief The functor call. + * \param n The expression node. + * \param args Additional arguments. + * \return The result of the call + */ + virtual R VisitDFPattern(const DFPattern& n, Args... 
args) { + CHECK(n.defined()); + static FType vtable = InitVTable(); + return vtable(n, this, std::forward(args)...); + } + // Functions that can be overriden by subclass + virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const TupleGetItemPatternNode* op, + Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPatternDefault_(const Object* op, Args...) { + LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); + throw; + } + + private: + // initialize the vtable. + static FType InitVTable() { + FType vtable; + // Set dispatch + RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode); + return vtable; + } +}; + +/*! + * \brief A simple visitor wrapper around DFPatternFunctor. + * Recursively visit the content. + * + * DFPatternVisitor treats the Pattern as dataflow graph,and only visit each Expr node once. + */ +class DFPatternVisitor : public DFPatternFunctor { + public: + void VisitDFPattern(const DFPattern& pattern) override; + void VisitDFPattern_(const AltPatternNode* op) override; + void VisitDFPattern_(const AttrPatternNode* op) override; + void VisitDFPattern_(const CallPatternNode* op) override; + void VisitDFPattern_(const DominatorPatternNode* op) override; + void VisitDFPattern_(const ExprPatternNode* op) override; + void VisitDFPattern_(const TupleGetItemPatternNode* op) override; + void VisitDFPattern_(const TuplePatternNode* op) override; + void VisitDFPattern_(const TypePatternNode* op) override; + void VisitDFPattern_(const VarPatternNode* op) override; + void VisitDFPattern_(const WildcardPatternNode* op) override; + + protected: + // set of already-visited nodes + std::unordered_set visited_; +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_ diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py new file mode 100644 index 000000000000..ca324bc444ec --- /dev/null +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -0,0 +1,581 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""The Relay Pattern Language and tooling."""
from tvm.relay import Expr
import tvm._ffi
from ...ir.base import Node
from ...ir import make_node
from ...runtime import Object
from ... import _ffi as tvm_ffi
from ..op import get
from . import _ffi as ffi


def register_df_node(type_key=None):
    """Register a Relay node type.

    Parameters
    ----------
    type_key : str or cls
        The type key of the node.
    """
    if not isinstance(type_key, str):
        return tvm._ffi.register_object(
            "relay.dataflow_pattern." + type_key.__name__)(type_key)
    return tvm._ffi.register_object(type_key)


class DFPattern(Node):
    """Base class of all Patterns.
    """

    def __call__(self, *args):
        return CallPattern(self, list(args))

    def __or__(self, other):
        return AltPattern(self, other)

    def __add__(self, other):
        return is_op("add")(self, other)

    def __sub__(self, other):
        return is_op("subtract")(self, other)

    def __mul__(self, other):
        return is_op("multiply")(self, other)

    def __truediv__(self, other):
        return is_op("divide")(self, other)

    def has_attr(self, attr_name: str, attr_value):
        """
        Add an attribute constraint to this pattern

        Parameters
        ----------
        attr_name: str
            The name of the attribute to match
        attr_value: Any
            The value of the attribute to match

        Returns
        -------
        result: tvm.relay.dataflow_pattern.DFPattern
            The resulting AttrPattern
        """
        attrs = make_node("DictAttrs", **{attr_name: attr_value})
        return AttrPattern(self, attrs)

    def has_type(self, ttype):
        """
        Add a type constraint to this pattern

        Parameters
        ----------
        ttype: tvm.relay.Type
            The type to match

        Returns
        -------
        result: tvm.relay.dataflow_pattern.DFPattern
            The resulting TypePattern
        """
        return has_type(ttype, self)

    def match(self, expr: Expr) -> bool:
        """
        Match this pattern to an expression

        Parameters
        ----------
        expr : tvm.relay.Expr
            The expression to match.

        Returns
        -------
        result: bool
            Whether or not the expression matches the pattern
        """
        return match(self, expr)

    def partition(self, expr: Expr) -> Expr:
        """
        Partition the expression into functions defined by this pattern

        Parameters
        ----------
        expr : tvm.relay.Expr
            The expression to match.

        Returns
        -------
        result : tvm.relay.Expr
            The Expression with matched subgraphs replaced by function calls to that subgraph
        """
        return partition(self, expr)

    def dominates(self, parent, path=None):
        """
        Create a dominator for this pattern

        Parameters
        ----------
        parent: tvm.relay.dataflow_pattern.DFPattern
            The parent pattern this pattern dominates.
        path: tvm.relay.dataflow_pattern.DFPattern
            The fuzzy path pattern.

        Returns
        -------
        result: tvm.relay.dataflow_pattern.DFPattern
            The resulting DominatorPattern
        """
        if path is None:
            path = wildcard()
        return DominatorPattern(parent, path, self)

    def optional(self, option_constructor):
        """
        Create an optional user of this pattern

        Parameters
        ----------
        option_constructor: function
            A function that takes a single Pattern parameter and returns
            a constructed pattern matching the option

        Returns
        -------
        result: tvm.relay.dataflow_pattern.DFPattern
            The resulting Pattern
        """
        return self | option_constructor(self)
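
# An illustrative sketch (not part of the API): the operator overloads above
# let patterns compose like expressions. For example, a pattern for any
# multiply feeding an add could be written as
#
#     pat = wildcard() * wildcard() + wildcard()
#
# which is shorthand for
# is_op("add")(is_op("multiply")(wildcard(), wildcard()), wildcard()).
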
def is_input(name: str = "") -> DFPattern:
    """
    Syntactic sugar for creating an optionally named VarPattern

    Parameters
    ----------
    name: str
        The name of the input pattern to match

    Returns
    -------
    result: tvm.relay.dataflow_pattern.DFPattern
        The resulting VarPattern
    """
    return VarPattern(name)


def is_op(op_name: str) -> DFPattern:
    """
    Syntactic sugar for creating an operator ExprPattern

    Parameters
    ----------
    op_name: str
        The name of the relay op

    Returns
    -------
    result: tvm.relay.dataflow_pattern.DFPattern
        The resulting ExprPattern
    """
    op = get(op_name)
    return ExprPattern(op)


def wildcard() -> DFPattern:
    """
    Syntactic sugar for creating a WildcardPattern

    Returns
    -------
    result: tvm.relay.dataflow_pattern.DFPattern
        The resulting WildcardPattern
    """
    return WildcardPattern()


def has_type(ttype, pattern: DFPattern = None) -> DFPattern:
    """
    Syntactic sugar for creating a TypePattern

    Parameters
    ----------
    ttype: tvm.relay.Type
        The type to match

    pattern: tvm.relay.dataflow_pattern.DFPattern, optional
        The pattern that needs type annotation; defaults to wildcard()

    Returns
    -------
    result: tvm.relay.dataflow_pattern.DFPattern
        The resulting TypePattern
    """
    if pattern is None:
        pattern = wildcard()
    return TypePattern(pattern, ttype)


def has_attr(attr_name: str, attr_value, pattern: DFPattern = None) -> DFPattern:
    """
    Syntactic sugar for creating an AttrPattern

    Parameters
    ----------
    attr_name: str
        The name of the attribute to match

    attr_value: Any
        The value of the attribute to match

    pattern: tvm.relay.dataflow_pattern.DFPattern, optional
        The input pattern; defaults to wildcard()

    Returns
    -------
    result: tvm.relay.dataflow_pattern.DFPattern
        The resulting AttrPattern
    """
    if pattern is None:
        pattern = wildcard()
    return pattern.has_attr(attr_name, attr_value)


def dominates(parent: DFPattern, path: DFPattern, child: DFPattern) -> DFPattern:
    """
    Syntactic sugar for creating a DominatorPattern

    Parameters
    ----------
    parent: tvm.relay.dataflow_pattern.DFPattern
        The parent pattern.
    path: tvm.relay.dataflow_pattern.DFPattern
        The fuzzy path pattern.
    child: tvm.relay.dataflow_pattern.DFPattern
        The child pattern.

    Returns
    -------
    result: tvm.relay.dataflow_pattern.DFPattern
        The resulting DominatorPattern
    """
    return DominatorPattern(parent, path, child)


def match(pattern: DFPattern, expr: Expr) -> bool:
    """
    Match a pattern to an expression

    Parameters
    ----------
    pattern: tvm.relay.dataflow_pattern.DFPattern
        The input pattern.
    expr : tvm.relay.Expr
        The expression to match.

    Returns
    -------
    result: bool
        Whether or not the expression matches the pattern
    """
    return ffi.match(pattern, expr)
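
# An illustrative usage sketch (assumes tvm.relay is imported as `relay`):
#
#     pat = is_op("nn.relu")(wildcard())
#     x = relay.var("x")
#     assert match(pat, relay.op.nn.relu(x))
#     assert not match(pat, relay.op.nn.tanh(x))
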
@register_df_node
class ExprPattern(DFPattern):
    """A pattern which matches a literal expression.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The expression to match.
    """

    def __init__(self, expr: Expr):
        self.__init_handle_by_constructor__(ffi.ExprPattern, expr)


@register_df_node
class VarPattern(DFPattern):
    """A local variable in Relay.

    Local variables can be used to declare input
    arguments to a function, or intermediate variables.

    Parameters
    ----------
    name_hint: str
        The name of the variable.
        This name only acts as a hint, and is not used
        for equality.

    type_annotation: tvm.relay.Type, optional
        The type annotation on the variable.
    """

    def __init__(self, name_hint: str, type_annotation=None):
        self.__init_handle_by_constructor__(
            ffi.VarPattern, name_hint, type_annotation)


@register_df_node
class CallPattern(DFPattern):
    """A pattern matching a function call node in Relay.

    Parameters
    ----------
    op: relay.dataflow_pattern.DFPattern
        The operation to be called.

    args: List[relay.dataflow_pattern.DFPattern]
        The arguments to the call.

    attrs: Optional[tvm.Attrs]
        Attributes to the call, can be None

    type_args: Optional[List[tvm.relay.Type]]
        The additional type arguments, this is only
        used in advanced use cases of template functions.
    """

    def __init__(self, op, args, attrs=None, type_args=None):
        if not type_args:
            type_args = []
        self.__init_handle_by_constructor__(
            ffi.CallPattern, op, args, attrs, type_args)


@register_df_node
class TuplePattern(DFPattern):
    """A pattern matching a Relay Tuple.

    Parameters
    ----------
    fields : List[tvm.relay.dataflow_pattern.DFPattern]
        The fields in the tuple.
    """

    def __init__(self, fields):
        self.__init_handle_by_constructor__(ffi.TuplePattern, fields)

    def __getitem__(self, index):
        if index >= len(self):
            raise IndexError("TuplePattern index out of range")
        return self.fields[index]

    def __len__(self):
        return len(self.fields)

    def astype(self, _):
        raise TypeError("astype cannot be used on TuplePattern")


@register_df_node
class TupleGetItemPattern(DFPattern):
    """Get index-th item from a TuplePattern.

    Parameters
    ----------
    tuple_value: tvm.relay.dataflow_pattern.DFPattern
        The input tuple expression.

    index: int
        The index.
    """

    def __init__(self, tuple_value: DFPattern, index):
        self.__init_handle_by_constructor__(
            ffi.TupleGetItemPattern, tuple_value, index)


@register_df_node
class AltPattern(DFPattern):
    """Create a Pattern that can match one of two conditions

    Parameters
    ----------
    left: tvm.relay.dataflow_pattern.DFPattern
        One possible matching Pattern
    right: tvm.relay.dataflow_pattern.DFPattern
        One possible matching Pattern
    """

    def __init__(self, left: DFPattern, right: DFPattern):
        self.__init_handle_by_constructor__(
            ffi.AltPattern, left, right)


@register_df_node
class WildcardPattern(DFPattern):
    """A pattern which matches anything.
    """

    def __init__(self):
        self.__init_handle_by_constructor__(ffi.WildcardPattern)


@register_df_node
class TypePattern(DFPattern):
    """A pattern that matches another pattern with a certain type annotation.

    Parameters
    ----------
    pattern: tvm.relay.dataflow_pattern.DFPattern
        The input pattern that needs type annotation

    ttype: tvm.relay.Type
        The type to match
    """

    def __init__(self, pattern: DFPattern, ttype):
        self.__init_handle_by_constructor__(
            ffi.TypePattern, pattern, ttype)
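
# An illustrative sketch (assumptions only): matching the first element of a
# tuple built from two wildcards,
#
#     tuple_pat = TuplePattern([wildcard(), wildcard()])
#     item_pat = TupleGetItemPattern(tuple_pat, 0)
#
# item_pat matches expressions like relay.TupleGetItem(relay.Tuple([x, y]), 0).
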
@register_df_node
class AttrPattern(DFPattern):
    """Match an expression with a certain attribute.
    Currently only supports Op Attributes, not call Attributes.

    Parameters
    ----------
    pattern: tvm.relay.dataflow_pattern.DFPattern
        The input pattern.

    attrs: tvm.Attrs
        The attributes to match
    """

    def __init__(self, pattern: DFPattern, attrs):
        self.__init_handle_by_constructor__(
            ffi.AttrPattern, pattern, attrs)


@register_df_node
class DominatorPattern(DFPattern):
    """Match a domination graph.

    Parameters
    ----------
    parent: tvm.relay.dataflow_pattern.DFPattern
        The parent, i.e., the single node which produces something,
        later aggregated by the child
    path: tvm.relay.dataflow_pattern.DFPattern
        The fuzzy path pattern between parent and child,
        typically matches elementwise ops
    child: tvm.relay.dataflow_pattern.DFPattern
        The last node in the domination which is the end user
        for all nodes in the path and the parent
    """

    def __init__(self, parent: DFPattern, path: DFPattern, child: DFPattern):
        self.__init_handle_by_constructor__(
            ffi.DominatorPattern, parent, path, child)


class DFPatternCallback:
    """A Callback for Pattern Rewriting

    When rewrite is called on this DFPatternCallback, the backend will find matches for the
    pattern, call the callback function, and replace the matched expression with whatever
    the callback returns.

    Users are expected to inherit from this class and provide a "self.pattern" to match
    """

    def rewrite(self, expr: Expr) -> Expr:
        """
        Rewrite expression with this callback

        Parameters
        ----------
        expr : tvm.relay.Expr
            The expression to rewrite.

        Returns
        -------
        result : tvm.relay.Expr
            The Expression with matched subgraphs rewritten by the callbacks
        """
        return rewrite(self, expr)

    def callback(self, pre, post, node_map):
        """
        Callback function to use when we found a match to the pattern

        Parameters
        ----------
        pre : tvm.relay.Expr
            The matching expression from the original graph.
        post : tvm.relay.Expr
            The matching expression with rewritten inputs
        node_map : Map(DFPattern, List(Expr))
            The map between patterns and matched expressions

        Returns
        -------
        result : tvm.relay.Expr
            The Expression with matched subgraph rewritten by the callback
        """
        raise NotImplementedError()


class _DFPatternCallback(Object):
    """C++ implementation"""
    def __init__(self, pattern, callback):
        self.__init_handle_by_constructor__(
            ffi.DFPatternCallback, pattern, callback)
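
# An illustrative callback sketch (the pattern and the constant are
# assumptions for this example, not part of this module): rewrite `x + x`
# into `2 * x`. Assumes tvm.relay is imported as `relay`.
#
#     class DoubleCallback(DFPatternCallback):
#         def __init__(self):
#             self.x = wildcard()
#             self.pattern = self.x + self.x
#
#         def callback(self, pre, post, node_map):
#             x = node_map[self.x][0]
#             return relay.const(2.0) * x
#
#     # out = rewrite(DoubleCallback(), expr)
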
def rewrite(callbacks, expr: Expr) -> Expr:
    """
    Rewrite expression with the given callbacks

    Parameters
    ----------
    callbacks: tvm.relay.dataflow_pattern.DFPatternCallback
        The input callback or list of callbacks.
    expr : tvm.relay.Expr
        The expression to rewrite.

    Returns
    -------
    result : tvm.relay.Expr
        The Expression with matched subgraphs rewritten by the callbacks
    """
    if isinstance(callbacks, DFPatternCallback):
        tmp = [_DFPatternCallback(callbacks.pattern, callbacks.callback)]
    else:
        tmp = []
        for callback in callbacks:
            tmp.append(_DFPatternCallback(callback.pattern, callback.callback))

    return ffi.rewrite(tmp, expr)


def partition(pattern: DFPattern, expr: Expr) -> Expr:
    """
    Partition the expression into a series of functions that match the pattern

    Parameters
    ----------
    pattern: tvm.relay.dataflow_pattern.DFPattern
        The pattern to match
    expr : tvm.relay.Expr
        The expression to split into functions

    Returns
    -------
    result : tvm.relay.Expr
        The Expression with matched subgraphs replaced by function calls to that subgraph
    """
    return ffi.partition(pattern, expr)
diff --git a/python/tvm/relay/dataflow_pattern/_ffi.py b/python/tvm/relay/dataflow_pattern/_ffi.py
new file mode 100644
index 000000000000..b0a702c1d2f5
--- /dev/null
+++ b/python/tvm/relay/dataflow_pattern/_ffi.py
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""DataFlow Pattern Language FFI bindings."""
import tvm._ffi

tvm._ffi._init_api("relay.dataflow_pattern", __name__)
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
new file mode 100644
index 000000000000..81fc4f03d886
--- /dev/null
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -0,0 +1,650 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/ir/dataflow_matcher.cc
 * \brief The dataflow pattern matcher for Relay.
+ */ + +#include +#include +#include +#include + +#include + +#include "indexed_graph.h" + +namespace tvm { +namespace relay { + +// Pattern Matcher + +class DominatorMatcher; + +class DFPatternMatcher : public DFPatternFunctor { + public: + explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {} + bool Match(const DFPattern& pattern, const Expr& expr); + Map> GetMemo() { return Map>(memo_); } + + protected: + bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; + bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override; + bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override; + + void ClearMap(size_t watermark); + bool MatchesPath(const DominatorPatternNode* op, const Expr& expr); + bool DominatesParent(const DominatorPatternNode* op, const Expr& expr); + + std::unordered_map, ObjectHash, ObjectEqual> memo_; + std::vector matched_nodes_; + IndexedGraph expr_graph_; + bool memoize_ = true; +}; + +bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { + memo_.clear(); + matched_nodes_.clear(); + return VisitDFPattern(pattern, expr); +} + +void DFPatternMatcher::ClearMap(size_t watermark) { + for (size_t i = watermark; i < matched_nodes_.size(); ++i) { + memo_.erase(matched_nodes_[i]); + } + matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end()); +} + +bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) { + if (memoize_ && memo_.count(pattern)) { + CHECK_EQ(memo_[pattern].size(), 1); + return expr.same_as(memo_[pattern][0]); + } else { + auto watermark = matched_nodes_.size(); + auto out = DFPatternFunctor::VisitDFPattern(pattern, expr); + if (out) { + memo_[pattern].push_back(expr); + matched_nodes_.push_back(pattern); + } else { + ClearMap(watermark); + } + return out; + } +} + +bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) { + return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr); +} + +bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) { + bool matches = false; + if (const auto* op_node = expr.as()) { + Op op = GetRef(op_node); + auto attributes = attr_pattern->attrs.as()->dict; + for (auto kv : attributes) { + auto attr_name = kv.first; + auto attr_value = kv.second; + auto op_map = Op::GetAttr(attr_name); + if (op_map.count(op)) { + switch (op_map[op].type_code()) { + case kDLInt: + if (auto* val = kv.second.as()) { + matches = val->value == op_map[op].operator int64_t(); + } + break; + case kDLFloat: + if (auto* val = kv.second.as()) { + matches = val->value == op_map[op].operator double(); + } + break; + case kTVMStr: + if (auto* val = kv.second.as()) { + matches = val->value == op_map[op].operator std::string(); + } + break; + default: + CHECK(false) << "Unsupported type in Type Pattern 
Node"; + } + } + } + } + return matches; +} + +Array reverse(const Array& args) { + Array new_args; + for (auto it = args.rbegin(); it != args.rend(); ++it) { + new_args.push_back(*it); + } + return new_args; +} + +bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) { + // utilities + auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* { + if (op) { + if (auto* expr_pattern = op->op.as()) { + return expr_pattern->expr.as(); + } + } + return nullptr; + }; + auto is_pattern_op = [&get_op_node](const CallPatternNode* op, std::string op_type) { + if (const auto* op_node = get_op_node(op)) { + if (op_node->name == op_type) { + return true; + } + } + return false; + }; + auto is_expr_op = [](const Expr& expr, std::string op_type) { + if (const auto* call_node = expr.as()) { + if (const auto* op_node = call_node->op.as()) { + if (op_node->name == op_type) { + return true; + } + } + } + return false; + }; + // logic + auto watermark = matched_nodes_.size(); + if (const auto* call_node = expr.as()) { + auto matches_op = VisitDFPattern(op->op, call_node->op); + if (matches_op) { + auto watermark2 = matched_nodes_.size(); + + auto match_args = [this, &watermark2](const Array pattern_args, + const Array expr_args) { + bool matches = true; + size_t i = 0; + if (pattern_args.size() == expr_args.size()) { + while (matches && i < pattern_args.size()) { + matches &= VisitDFPattern(pattern_args[i], expr_args[i]); + ++i; + } + } else { + matches = false; + } + if (!matches) { + ClearMap(watermark2); + } + return matches; + }; + + // Standard case + if (match_args(op->args, call_node->args)) { + return true; + } + // Commutative Matching + if (const OpNode* op_node = get_op_node(op)) { + if ((op_node->name == "add") || (op_node->name == "multiply")) { + if (match_args(reverse(op->args), call_node->args)) { + return true; + } + } + } + } else { + ClearMap(watermark); + // associate divide/multiply + if (is_pattern_op(op, "divide")) { + if (const auto* arg_node = op->args[0].as()) { + if (is_pattern_op(arg_node, "multiply") && is_expr_op(expr, "multiply") && + (is_expr_op(call_node->args[0], "divide") || + is_expr_op(call_node->args[1], "divide"))) { + bool out = false; + for (size_t arg_id = 0; arg_id < 2; ++arg_id) { + auto div = CallPattern(op->op, {arg_node->args[arg_id], op->args[1]}, op->attrs, + op->type_args); + auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 1) % 2], div}, + arg_node->attrs, arg_node->type_args); + out = VisitDFPattern(mul, expr); + if (out) { + return true; + } else { + ClearMap(watermark); + } + } + return out; + } + } + } + if (is_pattern_op(op, "multiply")) { + // associate multiply/divide + for (size_t arg_id = 0; arg_id < 2; ++arg_id) { + if (auto* arg_node = op->args[arg_id].as()) { + if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, "divide") && + (is_expr_op(call_node->args[0], "multiply") || + is_expr_op(call_node->args[1], "multiply"))) { + auto mul = CallPattern(op->op, {arg_node->args[0], op->args[(arg_id + 1) % 2]}, + op->attrs, op->type_args); + auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]}, arg_node->attrs, + arg_node->type_args); + return VisitDFPattern(div, expr); + } + } + } + } + } + } + return false; +} + +// Recursively find the Dominator parent along all inputs paths. 
+bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) { + auto call_node = expr.as(); + for (auto node : expr_graph_.node_map_[expr]->inputs_) { + if (!(call_node && node->ref_ == call_node->op)) { + memoize_ = true; + if (VisitDFPattern(op->parent, node->ref_)) { + return true; + } else { + memoize_ = false; + if (!VisitDFPattern(op->path, node->ref_) || !MatchesPath(op, node->ref_)) { + return false; + } + } + } + } + return true; +} + +// Iteratively ensure that the parent is dominated somewhere by the child or the path +bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) { + std::stack stack; + std::unordered_set visited; + stack.push(expr); + while (!stack.empty()) { + Expr current = stack.top(); + stack.pop(); + for (auto node : expr_graph_.node_map_[current]->dominator_children_) { + if (visited.count(node->ref_) == 0) { + if (VisitDFPattern(op->parent, node->ref_)) { + return true; + } else { + stack.push(node->ref_); + } + visited.insert(node->ref_); + } + } + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) { + if (VisitDFPattern(op->child, expr)) { + bool matches_path = MatchesPath(op, expr); + memoize_ = true; + if (matches_path) { + return DominatesParent(op, expr); + } + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) { + return StructuralEqual()(op->expr, expr); +} + +bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) { + bool matches = false; + if (const auto* tuple_get_item_node = expr.as()) { + matches = (op->index == tuple_get_item_node->index) && + VisitDFPattern(op->tuple, tuple_get_item_node->tuple); + } + return matches; +} + +bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) { + bool matches = false; + if (const auto* tuple_node = expr.as()) { + if (op->fields.size() == tuple_node->fields.size()) { + matches = true; + size_t i = 0; + while (matches && i < op->fields.size()) { + matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]); + ++i; + } + } + } + return matches; +} + +Expr InferType(const Expr& expr) { + auto mod = IRModule::FromExpr(expr); + mod = transform::InferType()(mod); + if (expr.as()) { + return mod->Lookup("main"); + } else { + return mod->Lookup("main").as()->body; + } +} + +bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) { + auto expr_type = InferType(expr).as()->checked_type(); + return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr); +} + +bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) { + bool matches = false; + if (const auto* var_node = expr.as()) { + matches = true; + if (op->name_hint() != "") { + matches &= op->name_hint() == var_node->name_hint(); + } + } + return matches; +} + +bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) { + return true; +} + +bool MatchPattern(DFPattern pattern, Expr expr) { + return DFPatternMatcher(expr).Match(pattern, expr); +} + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern); + +/* \brief PatternGrouper does pre-rewriting pattern matching and analysis + * + * This class creates a number of groups of matched expressions, ensures they don't overlap, and + * returns them to the caller for post-analysis rewriting. 
+ * + * This is primarily needed to support the post-dominator analysis required for dominator pattern + * matching. + */ +class PatternGrouper : protected MixedModeVisitor { + public: + /* \brief Internal Group class for storing analysis */ + struct Group { + Expr root_node; + int gid; + Map> matched_nodes; + Function function; + Array args; + }; + + /* \brief Return the group assignments of expressions */ + const std::unordered_map& GetGIDAssignments() { + return gid_assignments_; + } + /* \brief Group expressions that match the pattern */ + const std::vector& GroupMatches(const DFPattern& pattern, const Expr& pre) { + groups_ = {Group()}; + gid_assignments_.clear(); + visit_counter_.clear(); + + pattern_ = pattern; + pattern_graph_ = CreateIndexedGraph(pattern_); + auto matcher = DFPatternMatcher(pre); + matcher_ = &matcher; + this->VisitExpr(pre); + return this->groups_; + } + + protected: + void VisitLeaf(const Expr& pre) override { + if (matcher_->Match(pattern_, pre)) { + CreateGroup(pre); + } + } + + /* \brief Creates a new set of nodes based on Group inputs, used to create functions and perform + * group overlap analysis */ + class MatchExtractor : public ExprMutator { + public: + explicit MatchExtractor(const std::unordered_map& inputs) + : inputs_(inputs) {} + const std::unordered_map& GetMemo() { return this->memo_; } + + protected: + Expr VisitExpr(const Expr& pre) override { + if (inputs_.count(pre)) { + return inputs_.at(pre); + } + return ExprMutator::VisitExpr(pre); + } + const std::unordered_map inputs_; + }; + + /* \brief Create a group based on a matched expression */ + void CreateGroup(const Expr& expr) { + int var_number = 0; + + auto node_map = matcher_->GetMemo(); + + // Get fuzzy patterns + std::unordered_set fuzzy_matches; + for (auto node : pattern_graph_.topological_order_) { + if (auto op = node->ref_.as()) { + for (auto fuzzy_op : {op->parent, op->path}) { + for (auto match : node_map[fuzzy_op]) { + fuzzy_matches.insert(match); + } + } + } + } + + // Create input variables + Group group; + group.root_node = expr; + group.matched_nodes = node_map; + + std::unordered_map inputs; + Array params; + for (auto node : pattern_graph_.topological_order_) { + if (node->inputs_.size() == 0) { + if (node_map.count(node->ref_)) { + auto matches = node_map[node->ref_]; + for (auto match : matches) { + if (fuzzy_matches.count(match) == 0 && match.as() == nullptr && + match.as() == nullptr && match.as() == nullptr) { + inputs[match] = Var( + "FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number), + NullValue()); + group.args.push_back(match); + params.push_back(inputs[match]); + var_number++; + } + } + } + } + } + + graph_number_++; + + // Extract a Function. Used in Partition directly, + // used to determine Group overlap in other passes + auto extractor = MatchExtractor(inputs); + auto body = extractor.Mutate(expr); + + // Verify the pattern still holds + CHECK(DFPatternMatcher(body).Match(pattern_, body)); + group.function = Function(params, body, NullValue(), Array()); + + // Check to make sure we aren't overlapping with another group + // The MatchExtractor will create a new graph by replacing nodes that match the inputs of the + // pattern with the input FunctionVar* Variables. The resulting memoization map will only + // contain nodes in the expression that matched the pattern. 
+
+  /*! \brief Create a group based on a matched expression */
+  void CreateGroup(const Expr& expr) {
+    int var_number = 0;
+
+    auto node_map = matcher_->GetMemo();
+
+    // Get fuzzy patterns
+    std::unordered_set<Expr, ObjectHash, ObjectEqual> fuzzy_matches;
+    for (auto node : pattern_graph_.topological_order_) {
+      if (auto op = node->ref_.as<DominatorPatternNode>()) {
+        for (auto fuzzy_op : {op->parent, op->path}) {
+          for (auto match : node_map[fuzzy_op]) {
+            fuzzy_matches.insert(match);
+          }
+        }
+      }
+    }
+
+    // Create input variables
+    Group group;
+    group.root_node = expr;
+    group.matched_nodes = node_map;
+
+    std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> inputs;
+    Array<Var> params;
+    for (auto node : pattern_graph_.topological_order_) {
+      if (node->inputs_.size() == 0) {
+        if (node_map.count(node->ref_)) {
+          auto matches = node_map[node->ref_];
+          for (auto match : matches) {
+            if (fuzzy_matches.count(match) == 0 && match.as<OpNode>() == nullptr &&
+                match.as<FunctionNode>() == nullptr && match.as<ConstantNode>() == nullptr) {
+              inputs[match] = Var(
+                  "FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number),
+                  NullValue<Type>());
+              group.args.push_back(match);
+              params.push_back(inputs[match]);
+              var_number++;
+            }
+          }
+        }
+      }
+    }
+
+    graph_number_++;
+
+    // Extract a Function. Used in Partition directly,
+    // used to determine Group overlap in other passes
+    auto extractor = MatchExtractor(inputs);
+    auto body = extractor.Mutate(expr);
+
+    // Verify the pattern still holds
+    CHECK(DFPatternMatcher(body).Match(pattern_, body));
+    group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
+
+    // Check to make sure we aren't overlapping with another group
+    // The MatchExtractor will create a new graph by replacing nodes that match the inputs of the
+    // pattern with the input FunctionVar* Variables. The resulting memoization map will only
+    // contain nodes in the expression that matched the pattern. If a non-input node of the pattern
+    // (i.e., some piece of computation) overlaps with the nodes in a previous group, we'll have a
+    // situation where we try to rewrite the same node twice in the second rewriting or partition
+    // pass. This isn't valid, so we check for it here. We ignore Ops, functions, and constants
+    // because they exist more globally outside of the fusion.
+    for (auto kv : extractor.GetMemo()) {
+      if (gid_assignments_.count(kv.first) != 0 && inputs.count(kv.first) == 0 &&
+          kv.first.as<OpNode>() == nullptr && kv.first.as<FunctionNode>() == nullptr &&
+          kv.first.as<ConstantNode>() == nullptr) {
+        // Exit due to overlapping partitions
+        return;
+      }
+    }
+    // Assign Group Ids
+    group.gid = ++gid_;
+    for (auto kv : extractor.GetMemo()) {
+      gid_assignments_[kv.first] = gid_;
+    }
+
+    // Save Group
+    groups_.emplace_back(std::move(group));
+    CHECK_EQ(groups_[gid_].gid, gid_);
+  }
+
+  // Internal State
+  DFPattern pattern_;
+  std::vector<Group> groups_;
+  std::unordered_map<Expr, int, ObjectHash, ObjectEqual> gid_assignments_;
+  DFPatternMatcher* matcher_ = nullptr;
+  IndexedGraph<DFPattern> pattern_graph_;
+  int gid_ = 0;
+  int graph_number_ = 0;
+};
+
+// Rewrite
+
+DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function) {
+  ObjectPtr<DFPatternCallbackNode> n = make_object<DFPatternCallbackNode>();
+  n->pattern_ = std::move(pattern);
+  n->function_ = std::move(function);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode);
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DFPatternCallback")
+    .set_body_typed([](DFPattern pattern, PackedFunc function) {
+      return DFPatternCallback(pattern, function);
+    });
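+
+// Note: the PackedFunc wrapped by a DFPatternCallback is invoked by the
+// PatternRewriter below as function_(pre, post, node_map), where `pre` is the
+// matched subexpression before rewriting, `post` is the same subexpression
+// with any already-rewritten sub-graphs substituted in, and `node_map` maps
+// each sub-pattern to the expressions it matched. It must return the
+// replacement Expr.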
+
+/*! \brief PatternRewriter rewrites the expression by finding matches and allowing a user-supplied
+ * callback function to rewrite those matches
+ *
+ * The class uses PatternGrouper to support the dominator pattern.
+ */
+class PatternRewriter : protected MixedModeMutator {
+ public:
+  PatternRewriter() {}
+  /*! \brief Rewrite can take a number of callbacks and will repeatedly rewrite the graph with the
+   * callbacks until it stops changing */
+  Expr Rewrite(const Array<DFPatternCallback>& callbacks, const Expr& pre) {
+    auto post = pre;
+    auto last = post;
+    // rewrite the graph until it stops changing to make sure all rewrites are complete
+    int count = 0;
+    do {
+      last = post;
+      for (auto callback : callbacks) {
+        callback_ = callback;
+        auto grouper = PatternGrouper();
+        groups_ = grouper.GroupMatches(callback_->pattern_, post);
+        gid_assignments_ = grouper.GetGIDAssignments();
+        memo_.clear();
+        post = this->VisitExpr(post);
+        count++;
+      }
+    } while (last != post && count < 100);
+    if (count >= 100) {
+      LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?";
+    }
+    return post;
+  }
+
+ protected:
+  Expr DispatchVisitExpr(const Expr& pre) override {
+    auto post = MixedModeMutator::DispatchVisitExpr(pre);
+    if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) {
+      // Convert the pre-rewrite node map to a post-rewrite node map
+      auto group = groups_[gid_assignments_[pre]];
+      std::unordered_map<DFPattern, Array<Expr>, ObjectHash, ObjectEqual> node_map;
+      for (auto kv : group.matched_nodes) {
+        Array<Expr> tmp;
+        for (size_t i = 0; i < kv.second.size(); ++i) {
+          tmp.push_back(this->memo_[kv.second[i]]);
+        }
+        node_map.insert({kv.first, tmp});
+      }
+      // run the user callback function
+      return callback_->function_(pre, post, Map<DFPattern, Array<Expr>>(node_map));
+    }
+    return post;
+  }
+
+  DFPatternCallback callback_;
+  std::vector<PatternGrouper::Group> groups_;
+  std::unordered_map<Expr, int, ObjectHash, ObjectEqual> gid_assignments_;
+};
+
+Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr) {
+  return PatternRewriter().Rewrite(callbacks, expr);
+}
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.rewrite").set_body_typed(RewritePatterns);
+
+/*! \brief PatternPartitioner replaces expressions that match a pattern with function calls that
+ * perform the same computation but allow for further analysis and lowering.
+ *
+ * The class uses PatternGrouper to support the dominator pattern.
+ */
+class PatternPartitioner : protected MixedModeMutator {
+ public:
+  Expr Partition(const DFPattern& pattern, const Expr& pre) {
+    auto grouper = PatternGrouper();
+    groups_ = grouper.GroupMatches(pattern, pre);
+    gid_assignments_ = grouper.GetGIDAssignments();
+    return this->VisitExpr(pre);
+  }
+
+ protected:
+  Expr RewritePartition(const PatternGrouper::Group& group) {
+    Array<Expr> args;
+    for (size_t i = 0; i < group.args.size(); ++i) {
+      args.push_back(memo_[group.args[i]]);
+    }
+    return Call(group.function, args);
+  }
+
+  Expr DispatchVisitExpr(const Expr& pre) override {
+    auto post = MixedModeMutator::DispatchVisitExpr(pre);
+    if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) {
+      post = RewritePartition(groups_[gid_assignments_[pre]]);
+    }
+    return post;
+  }
+
+  std::vector<PatternGrouper::Group> groups_;
+  std::unordered_map<Expr, int, ObjectHash, ObjectEqual> gid_assignments_;
+};
+
+Expr PartitionPattern(DFPattern pattern, Expr expr) {
+  return PatternPartitioner().Partition(pattern, expr);
+}
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.partition").set_body_typed(PartitionPattern);
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc
new file mode 100644
index 000000000000..826a035ca6ba
--- /dev/null
+++ b/src/relay/ir/dataflow_pattern.cc
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/ir/dataflow_pattern.cc
+ * \brief The dataflow pattern language for Relay.
+ */
+#include <tvm/relay/dataflow_pattern.h>
+
+namespace tvm {
+namespace relay {
+
+ExprPattern::ExprPattern(Expr expr) {
+  ObjectPtr<ExprPatternNode> n = make_object<ExprPatternNode>();
+  n->expr = std::move(expr);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(ExprPatternNode);
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ExprPattern").set_body_typed([](Expr e) {
+  return ExprPattern(e);
+});
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<ExprPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const ExprPatternNode*>(ref.get());
+      p->Print(node->expr);
+    });
+
+VarPattern::VarPattern(String name_hint, Type type_annotation) {
+  ObjectPtr<VarPatternNode> n = make_object<VarPatternNode>();
+  n->name = std::move(name_hint);
+  n->type_annotation = std::move(type_annotation);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(VarPatternNode);
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.VarPattern")
+    .set_body_typed([](String name_hint, Type type_annotation) {
+      return VarPattern(name_hint, type_annotation);
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<VarPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const VarPatternNode*>(ref.get());
+      p->stream << "VarPattern(" << node->name_hint();
+      if (node->type_annotation.defined()) {
+        p->stream << ", ty=";
+        p->Print(node->type_annotation);
+      }
+      p->stream << ")";
+    });
+
+CallPattern::CallPattern(DFPattern op, Array<DFPattern> args, Attrs attrs, Array<Type> type_args) {
+  ObjectPtr<CallPatternNode> n = make_object<CallPatternNode>();
+  n->op = std::move(op);
+  n->args = std::move(args);
+  n->attrs = std::move(attrs);
+  n->type_args = std::move(type_args);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(CallPatternNode);
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.CallPattern")
+    .set_body_typed([](DFPattern op, Array<DFPattern> args, Attrs attrs, Array<Type> type_args) {
+      return CallPattern(op, args, attrs, type_args);
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<CallPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const CallPatternNode*>(ref.get());
+      p->stream << "CallPatternNode(" << node->op << ", " << node->args << ", " << node->attrs
+                << ", " << node->type_args << ")";
+    });
+
+TuplePattern::TuplePattern(tvm::Array<DFPattern> fields) {
+  ObjectPtr<TuplePatternNode> n = make_object<TuplePatternNode>();
+  n->fields = std::move(fields);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(TuplePatternNode);
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TuplePattern")
+    .set_body_typed([](tvm::Array<DFPattern> fields) { return TuplePattern(fields); });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<TuplePatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const TuplePatternNode*>(ref.get());
+      p->stream << "TuplePattern(" << node->fields << ")";
+    });
+
+TupleGetItemPattern::TupleGetItemPattern(DFPattern tuple, int index) {
+  ObjectPtr<TupleGetItemPatternNode> n = make_object<TupleGetItemPatternNode>();
+  n->tuple = std::move(tuple);
+  n->index = index;
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(TupleGetItemPatternNode);
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TupleGetItemPattern")
+    .set_body_typed([](DFPattern tuple, int index) { return TupleGetItemPattern(tuple, index); });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<TupleGetItemPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const TupleGetItemPatternNode*>(ref.get());
+      p->stream << "TupleGetItemPatternNode(" << node->tuple << ", " << node->index << ")";
+    });
+
+AltPattern::AltPattern(DFPattern left, DFPattern right) {
+  ObjectPtr<AltPatternNode> n = make_object<AltPatternNode>();
+  n->left = std::move(left);
+  n->right = std::move(right);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(AltPatternNode);
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.AltPattern")
+    .set_body_typed([](DFPattern left, DFPattern right) { return AltPattern(left, right); });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<AltPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const AltPatternNode*>(ref.get());
+      p->stream << "AltPattern(" << node->left << " | " << node->right << ")";
+    });
+
+TVM_REGISTER_NODE_TYPE(WildcardPatternNode);
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern").set_body_typed([]() {
+  auto w = WildcardPattern(make_object<WildcardPatternNode>());
+  return w;
+});
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<WildcardPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      p->stream << "*";
+    });
+
+TypePattern::TypePattern(DFPattern pattern, Type type) {
+  ObjectPtr<TypePatternNode> n = make_object<TypePatternNode>();
+  n->pattern = std::move(pattern);
+  n->type = std::move(type);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(TypePatternNode);
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.TypePattern")
+    .set_body_typed([](DFPattern pattern, Type type) { return TypePattern(pattern, type); });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<TypePatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const TypePatternNode*>(ref.get());
+      p->stream << "TypePattern(" << node->pattern << " has type " << node->type << ")";
+    });
+
+AttrPattern::AttrPattern(DFPattern pattern, Attrs attrs) {
+  ObjectPtr<AttrPatternNode> n = make_object<AttrPatternNode>();
+  n->pattern = std::move(pattern);
+  n->attrs = std::move(attrs);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(AttrPatternNode);
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.AttrPattern")
+    .set_body_typed([](DFPattern pattern, Attrs attrs) { return AttrPattern(pattern, attrs); });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<AttrPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const AttrPatternNode*>(ref.get());
+      p->stream << "AttrPattern(" << node->pattern << " has attributes " << node->attrs << ")";
+    });
+
+DominatorPattern::DominatorPattern(DFPattern parent, DFPattern path, DFPattern child) {
+  ObjectPtr<DominatorPatternNode> n = make_object<DominatorPatternNode>();
+  n->parent = std::move(parent);
+  n->path = std::move(path);
+  n->child = std::move(child);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(DominatorPatternNode);
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DominatorPattern")
+    .set_body_typed([](DFPattern parent, DFPattern path, DFPattern child) {
+      return DominatorPattern(parent, path, child);
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<DominatorPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const DominatorPatternNode*>(ref.get());
+      p->stream << "DominatorPattern(" << node->parent << ", " << node->path << ", " << node->child
+                << ")";
+    });
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/ir/dataflow_pattern_functor.cc b/src/relay/ir/dataflow_pattern_functor.cc
new file mode 100644
index 000000000000..c7c34c804449
--- /dev/null
+++ b/src/relay/ir/dataflow_pattern_functor.cc
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/ir/dataflow_pattern_functor.cc
+ * \brief The dataflow pattern functor for Relay.
+ */
+
+#include <tvm/relay/dataflow_pattern_functor.h>
+
+namespace tvm {
+namespace relay {
+
+// DFPatternVisitor
+
+void DFPatternVisitor::VisitDFPattern(const DFPattern& pattern) {
+  if (this->visited_.count(pattern.get()) == 0) {
+    visited_.insert(pattern.get());
+    DFPatternFunctor::VisitDFPattern(pattern);
+  }
+}
+
+void DFPatternVisitor::VisitDFPattern_(const AltPatternNode* op) {
+  VisitDFPattern(op->left);
+  VisitDFPattern(op->right);
+}
+
+void DFPatternVisitor::VisitDFPattern_(const AttrPatternNode* op) { VisitDFPattern(op->pattern); }
+
+void DFPatternVisitor::VisitDFPattern_(const CallPatternNode* op) {
+  VisitDFPattern(op->op);
+  for (auto arg : op->args) {
+    VisitDFPattern(arg);
+  }
+}
+
+void DFPatternVisitor::VisitDFPattern_(const DominatorPatternNode* op) {
+  VisitDFPattern(op->parent);
+  VisitDFPattern(op->path);
+  VisitDFPattern(op->child);
+}
+
+void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {}
+
+void DFPatternVisitor::VisitDFPattern_(const TupleGetItemPatternNode* op) {
+  VisitDFPattern(op->tuple);
+}
+
+void DFPatternVisitor::VisitDFPattern_(const TuplePatternNode* op) {
+  for (auto field : op->fields) {
+    VisitDFPattern(field);
+  }
+}
+
+void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { VisitDFPattern(op->pattern); }
+
+void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {}
+
+void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {}
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc
index 18fd1c711dd0..684dae7cc481 100644
--- a/src/relay/ir/expr_functor.cc
+++ b/src/relay/ir/expr_functor.cc
@@ -142,7 +142,8 @@ void MixedModeVisitor::VisitExpr_(const TupleGetItemNode* op) {}
 
 void MixedModeMutator::VisitLeaf(const Expr& expr) {
   if (!memo_.count(expr)) {
-    this->DispatchVisitExpr(expr);
+    Expr ret = this->DispatchVisitExpr(expr);
+    memo_[expr] = ret;
   }
 }
 
@@ -163,9 +164,7 @@ Expr MixedModeMutator::VisitExpr(const Expr& expr) {
     return memo_[expr];
   } else {
     ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
-    Expr ret = this->DispatchVisitExpr(expr);
-    memo_[expr] = ret;
-    return ret;
+    return memo_[expr];
   }
 }
 
diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc
new file mode 100644
index 000000000000..79ec57426d66
--- /dev/null
+++ b/src/relay/ir/indexed_graph.cc
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/ir/indexed_graph.cc
+ * \brief Utilities for Creating Indexed Graphs.
+ */
+#include "indexed_graph.h"
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/dataflow_pattern_functor.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+
+// IndexedGraph
+
+IndexedGraph<Expr> CreateIndexedGraph(const Expr& expr) {
+  using NodePtr = std::shared_ptr<IndexedGraph<Expr>::Node>;
+  /*! \brief Creator creates an IndexedGraph and determines topological order */
+  class Creator : public MixedModeVisitor {
+   public:
+    IndexedGraph<Expr> CreateGraph(const Expr& expr) {
+      VisitExpr(expr);
+      graph_.node_map_[expr]->is_external_ = true;
+      return std::move(graph_);
+    }
+
+   protected:
+    void VisitLeaf(const Expr& expr) override {
+      MixedModeVisitor::VisitLeaf(expr);
+      auto node = std::make_shared<IndexedGraph<Expr>::Node>(expr, index_++);
+      graph_.node_map_[expr] = node;
+      graph_.topological_order_.push_back(node);
+    }
+    IndexedGraph<Expr> graph_;
+    size_t index_ = 0;
+  };
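+  /* Note: the Annotator below iterates over the Creator's topological order
+   * instead of recursing through the expression, mirroring MixedModeVisitor's
+   * approach of avoiding deep recursion on large dataflow graphs.
+   */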
Default visitation pushes the parent to the child's ouputs and the child to the parent's + * inputs*/ + void VisitExpr(const Expr& expr, NodePtr parent) override { + auto current = graph_.node_map_[expr]; + if (parent) { + current->outputs_.push_back(parent.get()); + parent->inputs_.push_back(current.get()); + } + } + + protected: + IndexedGraph graph_; + void VisitExpr_(const VarNode* op, NodePtr parent) override { + if (op->type_annotation.defined()) { + this->VisitType(op->type_annotation); + } + } + + void VisitExpr_(const GlobalVarNode* op, NodePtr parent) override {} + + void VisitExpr_(const ConstantNode* op, NodePtr parent) override {} + + void VisitExpr_(const TupleNode* op, NodePtr parent) override { + for (auto field : op->fields) { + this->VisitExpr(field, graph_.node_map_[GetRef(op)]); + } + } + + void VisitExpr_(const FunctionNode* op, NodePtr parent) override { + for (auto param : op->params) { + this->VisitExpr(param, graph_.node_map_[GetRef(op)]); + } + + this->VisitExpr(op->body, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const CallNode* op, NodePtr parent) override { + this->VisitExpr(op->op, graph_.node_map_[GetRef(op)]); + + for (auto ty_arg : op->type_args) { + this->VisitType(ty_arg); + } + + for (auto arg : op->args) { + this->VisitExpr(arg, graph_.node_map_[GetRef(op)]); + } + } + + void VisitExpr_(const LetNode* op, NodePtr parent) override { + this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->var, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->body, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const IfNode* op, NodePtr parent) override { + this->VisitExpr(op->cond, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->true_branch, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->false_branch, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const OpNode* op, NodePtr parent) override { return; } + + void VisitExpr_(const TupleGetItemNode* op, NodePtr parent) override { + this->VisitExpr(op->tuple, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const RefCreateNode* op, NodePtr parent) override { + this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const RefReadNode* op, NodePtr parent) override { + this->VisitExpr(op->ref, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const RefWriteNode* op, NodePtr parent) override { + this->VisitExpr(op->ref, graph_.node_map_[GetRef(op)]); + this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + } + + void VisitExpr_(const ConstructorNode* op, NodePtr parent) override { + for (const Type& t : op->inputs) { + this->VisitType(t); + } + this->VisitType(op->belong_to); + } + + void VisitExpr_(const MatchNode* op, NodePtr parent) override { + this->VisitExpr(op->data, graph_.node_map_[GetRef(op)]); + for (const Clause& c : op->clauses) { + this->VisitClause(c, graph_.node_map_[GetRef(op)]); + } + } + + void VisitClause(const Clause& op, NodePtr parent) { + this->VisitPattern(op->lhs); + this->VisitExpr(op->rhs, parent); + } + + void VisitPattern(const Pattern& p) { return; } + + void VisitType(const Type& t) { return; } + }; + return Annotator(Creator().CreateGraph(expr)).Annotate(); +} + +IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { + using NodePtr = std::shared_ptr::Node>; + /*! 
+  /*! \brief Creator creates an IndexedGraph and determines topological order */
+  class Creator : public DFPatternVisitor {
+   public:
+    IndexedGraph<DFPattern> CreateGraph(const DFPattern& pattern) {
+      VisitDFPattern(pattern);
+      graph_.node_map_[pattern]->is_external_ = true;
+      return std::move(graph_);
+    }
+
+   protected:
+    void VisitDFPattern(const DFPattern& pattern) override {
+      DFPatternVisitor::VisitDFPattern(pattern);
+      auto node = std::make_shared<IndexedGraph<DFPattern>::Node>(pattern, index_++);
+      graph_.node_map_[pattern] = node;
+      graph_.topological_order_.push_back(node);
+    }
+    IndexedGraph<DFPattern> graph_;
+    size_t index_ = 0;
+  };
+  /*! \brief Annotator takes an IndexedGraph, fills its forward outputs, and does dominator tree
+   * analysis.
+   *
+   * Annotator uses DFPatternFunctor to visit nodes, but iterates over them in pre-determined
+   * topological order instead of recursing.
+   */
+  class Annotator : public DFPatternFunctor<void(const DFPattern&, NodePtr)> {
+   public:
+    Annotator(const IndexedGraph<DFPattern>& graph) : graph_(graph) {}
+    IndexedGraph<DFPattern> Annotate() {
+      // Visit all of the nodes in topological order to get forward outputs
+      for (const auto& node : graph_.topological_order_) {
+        DFPatternFunctor::VisitDFPattern(node->ref_, nullptr);
+      }
+      // do the dominator analysis
+      graph_.PostDom();
+      return std::move(graph_);
+    }
+
+    /*! \brief Default visitation pushes the parent to the child's outputs and the child to the
+     * parent's inputs */
+    void VisitDFPattern(const DFPattern& pattern, NodePtr parent) override {
+      auto current = graph_.node_map_[pattern];
+      if (parent) {
+        current->outputs_.push_back(parent.get());
+        parent->inputs_.push_back(current.get());
+      }
+    }
+
+   protected:
+    IndexedGraph<DFPattern> graph_;
+    void VisitDFPattern_(const AltPatternNode* op, NodePtr parent) override {
+      VisitDFPattern(op->left, graph_.node_map_[GetRef<DFPattern>(op)]);
+      VisitDFPattern(op->right, graph_.node_map_[GetRef<DFPattern>(op)]);
+    }
+
+    void VisitDFPattern_(const AttrPatternNode* op, NodePtr parent) override {
+      VisitDFPattern(op->pattern, graph_.node_map_[GetRef<DFPattern>(op)]);
+    }
+
+    void VisitDFPattern_(const CallPatternNode* op, NodePtr parent) override {
+      VisitDFPattern(op->op, graph_.node_map_[GetRef<DFPattern>(op)]);
+      for (auto arg : op->args) {
+        VisitDFPattern(arg, graph_.node_map_[GetRef<DFPattern>(op)]);
+      }
+    }
+
+    void VisitDFPattern_(const DominatorPatternNode* op, NodePtr parent) override {
+      VisitDFPattern(op->parent, graph_.node_map_[GetRef<DFPattern>(op)]);
+      VisitDFPattern(op->path, graph_.node_map_[GetRef<DFPattern>(op)]);
+      VisitDFPattern(op->child, graph_.node_map_[GetRef<DFPattern>(op)]);
+    }
+
+    void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {}
+
+    void VisitDFPattern_(const TupleGetItemPatternNode* op, NodePtr parent) override {
+      VisitDFPattern(op->tuple, graph_.node_map_[GetRef<DFPattern>(op)]);
+    }
+
+    void VisitDFPattern_(const TuplePatternNode* op, NodePtr parent) override {
+      for (auto field : op->fields) {
+        VisitDFPattern(field, graph_.node_map_[GetRef<DFPattern>(op)]);
+      }
+    }
+
+    void VisitDFPattern_(const TypePatternNode* op, NodePtr parent) override {
+      VisitDFPattern(op->pattern, graph_.node_map_[GetRef<DFPattern>(op)]);
+    }
+
+    void VisitDFPattern_(const VarPatternNode* op, NodePtr parent) override {}
+
+    void VisitDFPattern_(const WildcardPatternNode* op, NodePtr parent) override {}
+  };
+  return Annotator(Creator().CreateGraph(pattern)).Annotate();
+}
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/ir/indexed_graph.h b/src/relay/ir/indexed_graph.h
new file mode 100644
index 000000000000..d2524340f971
--- /dev/null
+++ b/src/relay/ir/indexed_graph.h
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/ir/indexed_graph.h
+ * \brief A pattern matcher for matching dataflow properties.
+ */
+#ifndef TVM_RELAY_IR_INDEXED_GRAPH_H_
+#define TVM_RELAY_IR_INDEXED_GRAPH_H_
+
+#include <tvm/relay/dataflow_pattern.h>
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief A Wrapper around a templated graph type
+ * Holds a forward-backward indexed representation of the graph and a dominator tree representation
+ * of the graph
+ *
+ * This class is templated and the implementation is in the header file so we can analyze both
+ * DFPattern and Expr with the same infrastructure.
+ *
+ * IndexedGraph should be instantiated through the CreateIndexedGraph utilities.
+ */
+template <typename T>
+class IndexedGraph {
+ public:
+  /*! \brief A Node that wraps the input type and represents the indexed graph and dominator tree */
+  struct Node {
+    /*! \brief Node Constructor
+     * \param ref The input graph node
+     * \param index The index of the node in topological order
+     */
+    Node(const T& ref, const size_t index) : ref_(ref), index_(index) {}
+
+    /*! \brief The input node */
+    const T ref_;
+    /*! \brief The topological order index */
+    const size_t index_;
+
+    /*! \brief A boolean to determine if this node is external to the graph */
+    bool is_external_ = false;
+    /*! \brief The forward inputs of the node */
+    std::vector<Node*> inputs_;
+    /*! \brief The forward outputs/users of the node */
+    std::vector<Node*> outputs_;
+
+    /*! \brief The depth of the node in the dominator tree */
+    size_t depth_;
+    /*! \brief The dominator parent/final user of the outputs of this node */
+    Node* dominator_parent_;
+    /*! \brief The nodes this node dominates */
+    std::vector<Node*> dominator_children_;
+  };
+  /*! \brief Construct the domination tree inside IndexedGraph */
+  void PostDom() {
+    for (size_t i = topological_order_.size(); i != 0; --i) {
+      size_t index = i - 1;
+      auto* current = topological_order_[index].get();
+      if (current->is_external_) {
+        current->depth_ = 1;
+        current->dominator_parent_ = nullptr;
+      } else {
+        auto parent = LeastCommonAncestor(current->outputs_);
+        current->depth_ = parent ? parent->depth_ + 1 : 1;
+        current->dominator_parent_ = parent;
+        // Guard against a null parent, which occurs for nodes with no outputs
+        if (parent) {
+          parent->dominator_children_.push_back(current);
+        }
+      }
+    }
+  }
+  /*! \brief Map of input nodes to IndexedGraph Nodes */
+  std::unordered_map<T, std::shared_ptr<Node>, ObjectHash, ObjectEqual> node_map_;
+  /*! \brief Topological IndexedGraph Nodes */
+  std::vector<std::shared_ptr<Node>> topological_order_;
+
+ protected:
+  /*! \brief Find the least common ancestor of all outputs of a node */
+  Node* LeastCommonAncestor(const std::vector<Node*>& outputs) {
+    if (outputs.size() == 0) {
+      return nullptr;
+    }
+    auto parent = outputs.at(0);
+    for (size_t i = 1; i < outputs.size(); ++i) {
+      parent = LeastCommonAncestor(parent, outputs.at(i));
+    }
+    return parent;
+  }
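+
+  /* Note: the pairwise walk below is the standard depth-based LCA scheme: the
+   * deeper node climbs to its dominator parent until both nodes sit at the
+   * same depth, then both climb together until they meet.
+   */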
+  /*! \brief Find the least common ancestor of two nodes */
+  Node* LeastCommonAncestor(Node* lhs, Node* rhs) {
+    if (lhs == nullptr || rhs == nullptr) {
+      return nullptr;
+    }
+    while (lhs != rhs) {
+      if (lhs->depth_ < rhs->depth_) {
+        rhs = rhs->dominator_parent_;
+      } else if (lhs->depth_ > rhs->depth_) {
+        lhs = lhs->dominator_parent_;
+      } else {
+        rhs = rhs->dominator_parent_;
+        lhs = lhs->dominator_parent_;
+      }
+    }
+    return lhs;
+  }
+};
+
+/*! \brief Create an Indexed Graph based on an Expr */
+IndexedGraph<Expr> CreateIndexedGraph(const Expr& expr);
+/*! \brief Create an Indexed Graph based on a DFPattern */
+IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern);
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_IR_INDEXED_GRAPH_H_
diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py
new file mode 100644
index 000000000000..a93a39be14d0
--- /dev/null
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -0,0 +1,853 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import relay
+from tvm.relay.dataflow_pattern import *
+import numpy as np
+
+# NB: the values below correspond to the C++ enum that specifies this;
+# we lose the type safety due to the Python/C++ calling
+# convention.
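+# (For reference, these mirror relay's TOpPattern operator attribute,
+# OpPatternKind, where kElemWise == 0 and kBroadcast == 1.)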
+K_ELEMWISE = 0
+K_BROADCAST = 1
+
+## NODE TESTS
+def test_expr_pattern():
+    ep = ExprPattern(relay.var('x', shape=(4, 1)))
+    assert isinstance(ep, ExprPattern)
+    assert isinstance(ep.expr, relay.Var)
+
+def test_var_pattern():
+    v = is_input("x")
+    assert isinstance(v, VarPattern)
+    assert v.name == "x"
+
+def test_wildcard_pattern():
+    wc = wildcard()
+    assert isinstance(wc, WildcardPattern)
+
+def test_CallPattern():
+    wc1 = wildcard()
+    wc2 = wildcard()
+    c = is_op("add")(wc1, wc2)
+    assert isinstance(c, CallPattern)
+    assert isinstance(c.args[0], WildcardPattern)
+    assert isinstance(c.args[1], WildcardPattern)
+
+def test_TuplePattern():
+    wc1 = wildcard()
+    wc2 = wildcard()
+    t = TuplePattern([wc1, wc2])
+    assert isinstance(t, TuplePattern)
+    assert isinstance(t.fields[0], WildcardPattern)
+    assert isinstance(t.fields[1], WildcardPattern)
+
+def test_TupleGetItemPattern():
+    wc1 = wildcard()
+    wc2 = wildcard()
+    t = TuplePattern([wc1, wc2])
+    tgi = TupleGetItemPattern(t, 1)
+    assert isinstance(tgi, TupleGetItemPattern)
+    assert isinstance(tgi.tuple, TuplePattern)
+    assert isinstance(tgi.tuple.fields[0], WildcardPattern)
+    assert isinstance(tgi.tuple.fields[1], WildcardPattern)
+
+def test_AltPattern():
+    is_add_or_sub = is_op('add') | is_op('subtract')
+    assert isinstance(is_add_or_sub, AltPattern)
+
+def test_TypePattern():
+    ttype = relay.TensorType((10, 10), "float32")
+    ty_pat = has_type(ttype)
+    assert isinstance(ty_pat, TypePattern)
+    assert ty_pat.type == ttype
+
+def test_AttrPattern():
+    op = is_op('add').has_attr("TOpPattern", K_ELEMWISE)
+    assert isinstance(op, AttrPattern)
+    assert op.attrs["TOpPattern"] == K_ELEMWISE
+
+## MATCHER TESTS
+
+def test_match_op():
+    assert is_op('add').match(relay.op.op.get("add"))
+
+def test_no_match_op():
+    assert not is_op('add').match(relay.op.op.get("subtract"))
+
+def test_match_op_or():
+    is_add_or_sub = is_op('add') | is_op('subtract')
+    assert is_add_or_sub.match(relay.op.op.get("add"))
+    assert is_add_or_sub.match(relay.op.op.get("subtract"))
+
+def test_match_call_commutive():
+    x = relay.var('x')
+    y = relay.var('y')
+    add_pattern = is_op('add')(is_input("x"), is_input("y"))
+    assert add_pattern.match(x + y)
+    assert add_pattern.match(y + x)
+    mul_pattern = is_op('multiply')(is_input("x"), is_input("y"))
+    assert mul_pattern.match(x * y)
+    assert mul_pattern.match(y * x)
+
+def test_no_match_call_commutive():
+    x = relay.var('x')
+    y = relay.var('y')
+    add_pattern = is_op('subtract')(is_input("x"), is_input("y"))
+    assert add_pattern.match(x - y)
+    assert not add_pattern.match(y - x)
+    add_pattern = is_op('divide')(is_input("x"), is_input("y"))
+    assert add_pattern.match(x / y)
+    assert not add_pattern.match(y / x)
+
+def test_match_call():
+    x = relay.var('x')
+    y = relay.var('y')
+    add_pattern = is_op('add')(wildcard(), wildcard())
+    assert add_pattern.match(x + y)
+
+def test_no_match_call():
+    x = relay.var('x')
+    y = relay.var('y')
+    add_pattern = is_op('add')(wildcard(), wildcard())
+    assert not add_pattern.match(x - y)
+
+def test_match_option():
+    x = relay.var('x')
+    w = relay.var('w')
+    b = relay.var('b')
+    pattern = is_op("nn.relu")(
+        is_op("nn.conv2d")(wildcard(), wildcard()).optional(
+            lambda x: is_op("nn.bias_add")(x, wildcard()))
+    )
+
+    conv2d = relay.op.nn.conv2d(x, w)
+    relu = relay.op.nn.relu(conv2d)
+    assert pattern.match(relu)
+
+    conv2d = relay.op.nn.conv2d(x, w)
+    bias_add = relay.op.nn.bias_add(conv2d, b)
+    relu = relay.op.nn.relu(bias_add)
+    assert pattern.match(relu)
is_op("nn.conv2d")(wildcard(), wildcard()) + pattern = pattern.optional(is_op('nn.relu')).optional(is_op("tanh")) + + conv2d = relay.op.nn.conv2d(x, w) + relu = relay.op.nn.relu(conv2d) + tanh = relay.op.tanh(conv2d) + tanh2 = relay.op.tanh(relu) + relu2 = relay.op.nn.relu(tanh) + assert pattern.match(conv2d) + assert pattern.match(relu) + assert pattern.match(tanh) + assert pattern.match(tanh2) + assert not pattern.match(relu2) + +def test_no_match_option(): + x = relay.var('x') + w = relay.var('w') + b = relay.var('b') + pattern = is_op("nn.relu")( + is_op("nn.conv2d")(wildcard(), wildcard() + ).optional(lambda x: is_op("nn.bias_add")(x, wildcard())) + ) + + conv2d = relay.op.nn.conv2d(x, w) + relu = relay.op.tanh(conv2d) + assert not pattern.match(relu) + + conv2d = relay.op.nn.dense(x, w) + relu = relay.op.tanh(conv2d) + assert not pattern.match(relu) + + conv2d = relay.op.nn.dense(x, w) + bias_add = relay.op.nn.bias_add(conv2d, b) + relu = relay.op.nn.relu(bias_add) + assert not pattern.match(relu) + + conv2d = relay.op.nn.conv2d(x, w) + bias_add = conv2d + w + relu = relay.op.nn.relu(bias_add) + assert not pattern.match(relu) + +def test_match_tuple(): + x = relay.var('x') + y = relay.var('y') + z = relay.op.op.get("add") + tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add"))) + assert tuple_pattern.match(relay.expr.Tuple((x,y,z))) + +def test_no_match_tuple(): + x = relay.var('x') + y = relay.var('y') + z = relay.op.op.get("add") + tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add"), wildcard())) + assert not tuple_pattern.match(relay.expr.Tuple((x,y,z))) + +def test_match_tuple(): + x = relay.var('x') + y = relay.var('y') + z = relay.op.op.get("add") + tuple_pattern = TuplePattern((is_input("x"), wildcard(), is_op("add"))) + tuple_get_item_pattern = TupleGetItemPattern(tuple_pattern, 1) + assert tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x,y,z)), 1)) + +def test_no_match_tuple(): + x = relay.var('x') + y = relay.var('y') + z = relay.op.op.get("add") + tuple_pattern = TuplePattern((is_input('x'), wildcard(), is_op("add"))) + tuple_get_item_pattern = TupleGetItemPattern(tuple_pattern, 1) + assert not tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x,y,z)), 2)) + +def test_match_type(): + x = relay.var('x', shape=(10, 10), dtype="float32") + ty_pat = has_type(relay.TensorType((10, 10), "float32")) + assert ty_pat.match(x) + +def test_no_match_type(): + x = relay.var('x', shape=(10, 10), dtype="int32") + ty_pat = has_type(relay.TensorType((10, 10), "float32")) + assert not ty_pat.match(x) + +def test_match_attr(): + op = is_op('add').has_attr("TOpPattern", K_BROADCAST) + op_pat = op(wildcard(), wildcard()) + x = relay.var('x') + y = relay.var('y') + assert op_pat.match(x + y) + +def test_no_match_attr(): + op = is_op('nn.dense').has_attr("TOpPattern", K_ELEMWISE) + op_pat = op(wildcard(), wildcard()) + x = relay.var('x') + y = relay.var('y') + assert not op_pat.match(relay.op.nn.dense(x, y)) + +def test_match_diamond(): + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + path1 = is_op('nn.relu')(is_conv2d) + path2 = is_op('nn.leaky_relu')(is_conv2d) + diamond = is_op('add')(path1, path2) + + # Expr + inp = relay.var('input') + weight = relay.var('weight') + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Check + assert diamond.match(out) + +def test_no_match_diamond(): 
+    # Pattern
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    path1 = is_op('nn.relu')(is_conv2d)
+    path2 = is_op('nn.leaky_relu')(is_conv2d)
+    diamond = is_op('add')(path1, path2)
+
+    # Expr
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert not diamond.match(leaky_relu)
+    assert not diamond.match(relu)
+
+def test_match_fake_diamond():
+    # Pattern
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    path1 = is_op('nn.relu')(is_conv2d)
+    path2 = is_op('nn.leaky_relu')(is_conv2d)
+    diamond = is_op('add')(path1, path2)
+
+    # Expr
+    input1 = relay.var('input1')
+    weight1 = relay.var('weight1')
+    conv2d1 = relay.op.nn.conv2d(input1, weight1)
+    inp2 = relay.var('input2')
+    weight2 = relay.var('weight2')
+    conv2d2 = relay.op.nn.conv2d(inp2, weight2)
+    relu = relay.op.nn.relu(conv2d1)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d2, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert not diamond.match(out)
+
+
+def test_match_dominator():
+    # Pattern
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard())
+    reduction = is_op('add')(wildcard(), wildcard())
+    diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
+
+    # Classic Diamond
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    relu = relay.op.nn.relu(relu)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert diamond.match(out)
+
+    # Deeper Branch
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    relu = relay.op.nn.relu(relu)
+    relu = relay.op.tanh(relu)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert diamond.match(out)
+
+    # Single Branch
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    relu = relay.op.nn.relu(relu)
+    tanh = relay.op.tanh(relu)
+    out = relu + tanh
+
+    # Check
+    assert diamond.match(out)
+
+    # Fuzzy path/nested Diamond
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard())
+    reduction = is_op('add')(wildcard(), wildcard())
+    diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
+
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    relu = relu + relu
+    tanh = relay.op.tanh(relu)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = tanh + leaky_relu
+
+    assert diamond.match(out)
+
+def test_not_match_dominator():
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard())
+    reduction = is_op('add')(wildcard(), wildcard())
+    diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
+
+    # Fake Diamond
+    input1 = relay.var('input1')
+    weight1 = relay.var('weight1')
+    conv2d1 = relay.op.nn.conv2d(input1, weight1)
+    inp2 = relay.var('input2')
+    weight2 = relay.var('weight2')
+    conv2d2 = relay.op.nn.conv2d(inp2, weight2)
+    relu = relay.op.nn.relu(conv2d1)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d2, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert not diamond.match(out)
+
+    # Add op that doesn't match K_ELEMWISE
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    relu = relu + relu
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert not diamond.match(out)
+
+    # Relu on the input instead of the conv
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(inp)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Check
+    assert not diamond.match(out)
+
+    # No conv
+    inp = relay.var('input')
+    relu = relay.op.nn.relu(inp)
+    relu = relay.op.nn.relu(relu)
+    tanh = relay.op.tanh(relu)
+    out = relu + tanh
+
+    # Check
+    assert not diamond.match(out)
+
+def test_rewrite():
+    x = relay.var('x')
+    y = relay.var('y')
+    add_pattern = is_op('add')(wildcard(), wildcard())
+    sub_pattern = is_op('subtract')(wildcard(), wildcard())
+    class TestRewrite(DFPatternCallback):
+        def __init__(self):
+            self.pattern = add_pattern
+        def callback(self, pre, post, node_map):
+            return post.args[0] - post.args[1]
+    out = rewrite(TestRewrite(), x + y)
+    assert sub_pattern.match(out)
+
+def test_not_fuse_multi_diamond():
+    # Pattern
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    path1 = is_op('nn.relu')(is_conv2d)
+    path2 = is_op('nn.leaky_relu')(is_conv2d)
+    diamond = is_op('add')(path1, path2)
+
+    # Expr
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+    out = out + conv2d
+    # Check
+    assert not diamond.match(out)
+
+class BatchnormCallback(DFPatternCallback):
+    def __init__(self):
+        self.x = wildcard()
+        self.var = wildcard()
+        self.mean = wildcard()
+        self.beta = wildcard()
+        self.gamma = wildcard()
+        self.eps = wildcard()
+
+        self.pattern = self.gamma * (self.x - self.mean) / is_op("sqrt")(self.var + self.eps) + self.beta
+
+    def callback(self, pre, post, node_map):
+        x = node_map[self.x][0]
+        var = node_map[self.var][0]
+        mean = node_map[self.mean][0]
+        beta = node_map[self.beta][0]
+        gamma = node_map[self.gamma][0]
+        eps = node_map[self.eps][0]
+        return relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=eps.data.asnumpy().item())[0]
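+
+# Note on BatchnormCallback: node_map maps each sub-pattern to the list of
+# expressions it matched, so node_map[self.x][0] retrieves the single match
+# bound to that wildcard. eps is expected to match a relay Constant, hence
+# eps.data.asnumpy().item() to recover the Python float for `epsilon`.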
+
+def test_fuse_batchnorm():
+    x = relay.var('x')
+    var = relay.var('var')
+    mean = relay.var('mean')
+    beta = relay.var('beta')
+    gamma = relay.var('gamma')
+
+    BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta
+
+    out = rewrite(BatchnormCallback(), BN)
+    assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0])
+
+def test_no_fuse_batchnorm():
+    x = relay.var('x')
+    var = relay.var('var')
+    mean = relay.var('mean')
+    beta = relay.var('beta')
+    gamma = relay.var('gamma')
+
+    fake_BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) - beta
+
+    out = rewrite(BatchnormCallback(), fake_BN)
+    assert tvm.ir.structural_equal(out, fake_BN)
+
+def test_fuse_double_batchnorm():
+    x = relay.var('x')
+    var = relay.var('var')
+    mean = relay.var('mean')
+    beta = relay.var('beta')
+    gamma = relay.var('gamma')
+
+    BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta
+    BN2 = gamma * (BN - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta
+
+    out = rewrite(BatchnormCallback(), BN2)
+
+    bn = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0]
+    bn2 = relay.op.nn.batch_norm(bn, gamma, beta, mean, var, epsilon=1e-5)[0]
+
+    assert tvm.ir.structural_equal(out, bn2)
+
+def test_partial_fuse_double_batchnorm():
+    x = relay.var('x')
+    var = relay.var('var')
+    mean = relay.var('mean')
+    beta = relay.var('beta')
+    gamma = relay.var('gamma')
+
+    BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) - beta
+    BN2 = gamma * (BN - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta
+
+    out = rewrite(BatchnormCallback(), BN2)
+
+    bn2 = relay.op.nn.batch_norm(BN, gamma, beta, mean, var, epsilon=1e-5)[0]
+
+    assert tvm.ir.structural_equal(out, bn2)
+
+def test_fuse_batchnorm_commutation():
+    x = relay.var('x')
+    var = relay.var('var')
+    mean = relay.var('mean')
+    beta = relay.var('beta')
+    gamma = relay.var('gamma')
+
+    # commute add
+    BN = beta + gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5))
+    out = rewrite(BatchnormCallback(), BN)
+    assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0])
+
+    # associate divide/multiply
+    BN = (gamma * (x - mean)) / relay.op.sqrt(var + relay.const(1e-5)) + beta
+    out = rewrite(BatchnormCallback(), BN)
+    assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0])
+
+    # associate multiply/divide
+    BN = gamma * ((x - mean) / relay.op.sqrt(var + relay.const(1e-5))) + beta
+    out = rewrite(BatchnormCallback(), BN)
+    assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)[0])
+
+def test_quadruple_rewrite_dominator():
+    class DominatorRemovalCallback(DFPatternCallback):
+        def __init__(self):
+            self.inp = wildcard()
+            self.weight = wildcard()
+
+            is_conv2d = is_op('nn.conv2d')(self.inp, self.weight)
+            is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard())
+            reduction = is_op('add')(wildcard(), wildcard())
+            self.pattern = dominates(is_conv2d, is_unary_elemwise, reduction)
+
+        def callback(self, pre, post, node_map):
+            inp = node_map[self.inp][0]
+            weight = node_map[self.weight][0]
+            return relay.op.nn.conv2d(inp, weight)
+
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    # Classic Diamond
+    conv2d = relay.op.nn.conv2d(inp, weight)
+    relu = relay.op.nn.relu(conv2d)
+    relu = relay.op.nn.relu(relu)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Deeper Branch
+    conv2d = relay.op.nn.conv2d(out, weight)
+    relu = relay.op.nn.relu(conv2d)
+    relu = relay.op.nn.relu(relu)
+    relu = relay.op.tanh(relu)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = relu + leaky_relu
+
+    # Single Branch
+    conv2d = relay.op.nn.conv2d(out, weight)
+    relu = relay.op.nn.relu(conv2d)
+    relu = relay.op.nn.relu(relu)
+    tanh = relay.op.tanh(relu)
+    out = relu + tanh
+
+    # Fuzzy path/nested Diamond
+    conv2d = relay.op.nn.conv2d(out, weight)
+    relu = relay.op.nn.relu(conv2d)
+    relu = relu + relu
+    tanh = relay.op.tanh(relu)
+    leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+    out = tanh + leaky_relu
+
+    one = relay.op.nn.conv2d(inp, weight)
+    two = relay.op.nn.conv2d(one, weight)
+    three = relay.op.nn.conv2d(two, weight)
+    four = relay.op.nn.conv2d(three, weight)
+
+    assert tvm.ir.structural_equal(DominatorRemovalCallback().rewrite(out), four)
+
+def algebraic_simplify(expr):
+    zero = (ExprPattern(relay.const(0)) | ExprPattern(relay.const(0.0)))
+    one = (ExprPattern(relay.const(1)) | ExprPattern(relay.const(1.0)))
+
+    class ElwiseNullCallback(DFPatternCallback):
+        def callback(self, pre, post, node_map):
+            return node_map[self.x][0]
+
+    class AddCallback(ElwiseNullCallback):
+        def __init__(self):
+            self.x = wildcard()
+            self.pattern = self.x + zero
+
+    class SubCallback(ElwiseNullCallback):
+        def __init__(self):
+            self.x = wildcard()
+            self.pattern = self.x - zero
+
+    class MulCallback(ElwiseNullCallback):
+        def __init__(self):
+            self.x = wildcard()
+            self.pattern = self.x * one
+
+    class DivCallback(ElwiseNullCallback):
+        def __init__(self):
+            self.x = wildcard()
+            self.pattern = self.x / one
+
+    class MulZeroCallback(ElwiseNullCallback):
+        def __init__(self):
+            self.x = zero
+            self.pattern = self.x * wildcard()
+
+    class ZeroDivCallback(ElwiseNullCallback):
+        def __init__(self):
+            self.x = zero
+            self.pattern = self.x / wildcard()
+
+    return rewrite([AddCallback(),
+                    SubCallback(),
+                    MulCallback(),
+                    DivCallback(),
+                    MulZeroCallback(),
+                    ZeroDivCallback()
+                    ], expr)
+
+def test_algebraic_simplify():
+    x = relay.Var('x')
+    y = relay.Var('y')
+
+    one = relay.const(1)
+    zero = relay.const(0)
+    onef = relay.const(1.0)
+    zerof = relay.const(0.0)
+
+    assert algebraic_simplify(x + zero) == x
+    assert algebraic_simplify(x + zerof) == x
+    assert algebraic_simplify(zero + x) == x
+    assert algebraic_simplify(zerof + x) == x
+
+    assert algebraic_simplify(x - zero) == x
+    assert algebraic_simplify(x - zerof) == x
+
+    assert algebraic_simplify(x * one) == x
+    assert algebraic_simplify(x * onef) == x
+    assert algebraic_simplify(one * x) == x
+    assert algebraic_simplify(onef * x) == x
+    assert algebraic_simplify(x * zero) == zero
+    assert algebraic_simplify(x * zerof) == zerof
+
+    assert algebraic_simplify(x / one) == x
+    assert algebraic_simplify(x / onef) == x
+    assert algebraic_simplify(zero / x) == zero
+    assert algebraic_simplify(zerof / x) == zerof
+
+    assert tvm.ir.structural_equal(algebraic_simplify((x + zero * y) / one + (y * one) - zero / x), x + y)
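+
+# Note: as algebraic_simplify shows, rewrite() also accepts a list of callbacks;
+# the underlying PatternRewriter reapplies them until the expression stops
+# changing (capped at 100 iterations).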
+
+def test_partition_dominator():
+    # Pattern
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard())
+    reduction = is_op('add')(wildcard(), wildcard())
+    diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
+
+    # Classic Diamond
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    def generate_diamond(inp, weight):
+        conv2d = relay.op.nn.conv2d(inp, weight)
+        relu = relay.op.nn.relu(conv2d)
+        relu = relay.op.nn.relu(relu)
+        leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+        return relu + leaky_relu
+    out = generate_diamond(inp * inp, weight * weight)
+    # Check
+    partitioned = diamond.partition(out)
+
+    i = relay.Var("input")
+    w = relay.Var("weight")
+    f = relay.Function([i, w], generate_diamond(i, w))
+    assert tvm.ir.structural_equal(partitioned, f(inp * inp, weight * weight))
+
+def test_quadruple_partition_dominator():
+    # Pattern
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard())
+    reduction = is_op('add')(wildcard(), wildcard())
+    diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
+
+    inp = relay.var('input')
+    weight = relay.var('weight')
+    # Classic Diamond
+    def classic_diamond(inp, weight):
+        conv2d = relay.op.nn.conv2d(inp, weight)
+        relu = relay.op.nn.relu(conv2d)
+        relu = relay.op.nn.relu(relu)
+        leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+        return relu + leaky_relu
+
+    # Deeper Branch
+    def deeper_diamond(inp, weight):
+        conv2d = relay.op.nn.conv2d(inp, weight)
+        relu = relay.op.nn.relu(conv2d)
+        relu = relay.op.nn.relu(relu)
+        relu = relay.op.tanh(relu)
+        leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+        return relu + leaky_relu
+
+    # Single Branch
+    def single_branch(inp, weight):
+        conv2d = relay.op.nn.conv2d(inp, weight)
+        relu = relay.op.nn.relu(conv2d)
+        relu = relay.op.nn.relu(relu)
+        tanh = relay.op.tanh(relu)
+        return relu + tanh
+
+    # Fuzzy path/nested Diamond
+    def nested_diamond(inp, weight):
+        conv2d = relay.op.nn.conv2d(inp, weight)
+        relu = relay.op.nn.relu(conv2d)
+        relu = relu + relu
+        tanh = relay.op.tanh(relu)
+        leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0)
+        return tanh + leaky_relu
+
+    partitioned = diamond.partition(
+        nested_diamond(
+            single_branch(
+                deeper_diamond(
+                    classic_diamond(inp, weight),
+                    weight),
+                weight),
+            weight
+        )
+    )
+
+    functions = []
+    for f in [classic_diamond, deeper_diamond, single_branch, nested_diamond]:
+        inpf = relay.var("input")
+        weightf = relay.var("weight")
+        functions.append(relay.Function([inpf, weightf], f(inpf, weightf)))
+
+    reference = functions[3](
+        functions[2](
+            functions[1](
+                functions[0](inp, weight),
+                weight),
+            weight),
+        weight
+    )
+    assert tvm.ir.structural_equal(partitioned, reference)
+
+def get_BN(x, var, mean, beta, gamma, eps=1e-5):
+    return gamma * (x - mean) / relay.op.sqrt(var + relay.const(eps)) + beta
+
+def test_parition_batchnorm():
+    x = relay.var('x')
+    var = relay.var('var')
+    mean = relay.var('mean')
+    beta = relay.var('beta')
+    gamma = relay.var('gamma')
+    BN = get_BN(x, var, mean, beta, gamma)
+
+    xf = relay.var('xf')
+    varf = relay.var('varf')
+    meanf = relay.var('meanf')
+    betaf = relay.var('betaf')
+    gammaf = relay.var('gammaf')
+    # Put the arguments in topological order for the reference
+    f = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf))
+
+    partitioned = BatchnormCallback().pattern.partition(BN)
+    assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, beta))
+
+def test_parition_double_batchnorm():
+    x = relay.var('x')
+    var = relay.var('var')
+    mean = relay.var('mean')
+    beta = relay.var('beta')
+    gamma = relay.var('gamma')
+
+    BN = gamma * (x - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta
+    BN2 = gamma * (BN - mean) / relay.op.sqrt(var + relay.const(1e-5)) + beta
+
+    xf = relay.var('xf')
+    varf = relay.var('varf')
+    meanf = relay.var('meanf')
+    betaf = relay.var('betaf')
+    gammaf = relay.var('gammaf')
+    f1 = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf))
+    # The partitioner doesn't replace duplicates, so we use two copies of the function
+    xf2 = relay.var('xf2')
+    varf2 = relay.var('varf2')
+    meanf2 = relay.var('meanf2')
+    betaf2 = relay.var('betaf2')
+    gammaf2 = relay.var('gammaf2')
+    f2 = relay.Function([gammaf2, xf2, meanf2, varf2, betaf2], get_BN(xf2, varf2, meanf2, betaf2, gammaf2))
+
+    partitioned = BatchnormCallback().pattern.partition(BN2)
+    reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
+    assert tvm.ir.structural_equal(partitioned, reference)
+
+if __name__ == "__main__":
+    test_match_op()
+    test_no_match_op()
+    test_match_op_or()
+    test_match_call()
+    test_no_match_call()
+    test_match_call_commutive()
+    test_no_match_call_commutive()
+    test_match_option()
+    test_no_match_option()
+    test_match_tuple()
+    test_no_match_tuple()
+    test_match_tuple_get_item()
+    test_no_match_tuple_get_item()
+    test_match_type()
+    test_no_match_type()
+    test_match_attr()
+    test_no_match_attr()
+    test_match_diamond()
+    test_no_match_diamond()
+    test_match_fake_diamond()
+    test_rewrite()
+    test_not_fuse_multi_diamond()
+    test_fuse_batchnorm()
+    test_no_fuse_batchnorm()
+    test_fuse_double_batchnorm()
+    test_partial_fuse_double_batchnorm()
+    test_fuse_batchnorm_commutation()
+    test_quadruple_rewrite_dominator()
+    test_match_dominator()
+    test_not_match_dominator()
+    test_algebraic_simplify()
+    test_partition_dominator()
+    test_quadruple_partition_dominator()
+    test_parition_batchnorm()
+    test_parition_double_batchnorm()