Skip to content

Commit

Permalink
More Robust Diamond Matcher
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart committed Mar 31, 2020
1 parent b29dead commit 175edf0
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
13 changes: 11 additions & 2 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ namespace relay {
class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
public:
bool Match(const DFPattern& pattern, const Expr& expr);
protected:
bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
Expand All @@ -42,10 +41,20 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
protected:
std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
};

bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
return VisitDFPattern(pattern, expr);
if (memo_.count(pattern)) {
return expr.same_as(memo_[pattern]);
} else {
auto out = VisitDFPattern(pattern, expr);
if (out) {
memo_[pattern] = expr;
}
return out;
}
}

bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
Expand Down
8 changes: 3 additions & 5 deletions tests/python/relay/test_df_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def test_no_match_attr():

def test_match_diamond():
# Pattern
is_conv2d = is_op('nn.conv2d')(is_input(), is_input())
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
path1 = is_op('nn.relu')(is_conv2d)
path2 = is_op('nn.leaky_relu')(is_conv2d)
diamond = is_op('add')(path1, path2)
Expand All @@ -171,7 +171,7 @@ def test_match_diamond():

def test_no_match_diamond():
# Pattern
is_conv2d = is_op('nn.conv2d')(is_input(), is_input())
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
path1 = is_op('nn.relu')(is_conv2d)
path2 = is_op('nn.leaky_relu')(is_conv2d)
diamond = is_op('add')(path1, path2)
Expand All @@ -190,9 +190,7 @@ def test_no_match_diamond():

def test_match_fake_diamond():
# Pattern
data_pat = is_input('data')
weight_pat = is_input('weight')
is_conv2d = is_op('nn.conv2d')(data_pat, weight_pat)
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
path1 = is_op('nn.relu')(is_conv2d)
path2 = is_op('nn.leaky_relu')(is_conv2d)
diamond = is_op('add')(path1, path2)
Expand Down

0 comments on commit 175edf0

Please sign in to comment.