Skip to content

Commit

Permalink
fix pattern topological order (apache#5612)
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart authored and trevor-m committed Jun 18, 2020
1 parent a4fb426 commit edc1655
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 5 deletions.
10 changes: 6 additions & 4 deletions src/relay/ir/indexed_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,12 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const DFPattern& pattern) {

protected:
void VisitDFPattern(const DFPattern& pattern) override {
DFPatternVisitor::VisitDFPattern(pattern);
auto node = std::make_shared<IndexedGraph<DFPattern>::Node>(pattern, index_++);
graph_.node_map_[pattern] = node;
graph_.topological_order_.push_back(node);
if (this->visited_.count(pattern.get()) == 0) {
DFPatternVisitor::VisitDFPattern(pattern);
auto node = std::make_shared<IndexedGraph<DFPattern>::Node>(pattern, index_++);
graph_.node_map_[pattern] = node;
graph_.topological_order_.push_back(node);
}
}
IndexedGraph<DFPattern> graph_;
size_t index_ = 0;
Expand Down
4 changes: 3 additions & 1 deletion src/relay/ir/indexed_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class IndexedGraph {
std::vector<Node*> outputs_;

/*! \brief The depth of the node in the dominator tree */
size_t depth_;
size_t depth_ = 0;
/*! \brief The dominator parent/final user of the outputs of this node */
Node* dominator_parent_;
/*! \brief The nodes this node dominates */
Expand Down Expand Up @@ -115,6 +115,8 @@ class IndexedGraph {
return nullptr;
}
while (lhs != rhs) {
CHECK(lhs);
CHECK(rhs);
if (lhs->depth_ < rhs->depth_) {
rhs = rhs->dominator_parent_;
} else if (lhs->depth_ > rhs->depth_) {
Expand Down
30 changes: 30 additions & 0 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,35 @@ def callback(self, pre, post, node_map):
out = rewrite(TestRewrite(), x + y)
assert sub_pattern.match(out)

def test_nested_rewrite():
class PatternCallback(DFPatternCallback):
def __init__(self, pattern):
self.pattern = pattern

def callback(self, pre, post, node_map):
return post

def gen():
x = relay.var('x')
y = relay.var('y')
y_add = relay.add(y, y)
n0 = relay.add(x, y_add)
n1 = relay.add(x, n0)
return relay.add(n1, n0)

def pattern():
a = wildcard()
b = wildcard()
n0 = is_op('add')(a, b)
n1 = is_op('add')(n0, a)
return is_op('add')(n0, n1)

out = gen()
pat = pattern()
new_out = rewrite(PatternCallback(pat), out)

assert tvm.ir.structural_equal(out, new_out)

def test_not_fuse_multi_diamond():
# Pattern
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
Expand Down Expand Up @@ -838,6 +867,7 @@ def test_parition_double_batchnorm():
test_no_match_diamond()
test_match_fake_diamond()
test_rewrite()
test_nested_rewrite()
test_fuse_batchnorm()
test_no_fuse_batchnorm()
test_fuse_double_batchnorm()
Expand Down

0 comments on commit edc1655

Please sign in to comment.