From 175edf0dceda9adf623b04852eb074eeb09fee98 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 31 Mar 2020 13:42:43 -0700 Subject: [PATCH] More Robust Diamond Matcher --- src/relay/ir/dataflow_matcher.cc | 13 +++++++++++-- tests/python/relay/test_df_pattern.py | 8 +++----- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 5e22e4b334ef2..d6340909da8e5 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -32,7 +32,6 @@ namespace relay { class DFPatternMatcher : public DFPatternFunctor { public: bool Match(const DFPattern& pattern, const Expr& expr); - protected: bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override; bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override; bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override; @@ -42,10 +41,20 @@ class DFPatternMatcher : public DFPatternFunctor memo_; }; bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { - return VisitDFPattern(pattern, expr); + if (memo_.count(pattern)) { + return expr.same_as(memo_[pattern]); + } else { + auto out = VisitDFPattern(pattern, expr); + if (out) { + memo_[pattern] = expr; + } + return out; + } } bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) { diff --git a/tests/python/relay/test_df_pattern.py b/tests/python/relay/test_df_pattern.py index 73d1d2f640dee..86a87a9581724 100644 --- a/tests/python/relay/test_df_pattern.py +++ b/tests/python/relay/test_df_pattern.py @@ -153,7 +153,7 @@ def test_no_match_attr(): def test_match_diamond(): # Pattern - is_conv2d = is_op('nn.conv2d')(is_input(), is_input()) + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) path1 = is_op('nn.relu')(is_conv2d) path2 = is_op('nn.leaky_relu')(is_conv2d) diamond = is_op('add')(path1, path2) @@ -171,7 +171,7 @@ def test_match_diamond(): def test_no_match_diamond(): # Pattern - is_conv2d = is_op('nn.conv2d')(is_input(), is_input()) + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) path1 = is_op('nn.relu')(is_conv2d) path2 = is_op('nn.leaky_relu')(is_conv2d) diamond = is_op('add')(path1, path2) @@ -190,9 +190,7 @@ def test_no_match_diamond(): def test_match_fake_diamond(): # Pattern - data_pat = is_input('data') - weight_pat = is_input('weight') - is_conv2d = is_op('nn.conv2d')(data_pat, weight_pat) + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) path1 = is_op('nn.relu')(is_conv2d) path2 = is_op('nn.leaky_relu')(is_conv2d) diamond = is_op('add')(path1, path2)