Skip to content

Commit

Permalink
enhance fusion for prarllel injectiveOD
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Mar 31, 2019
1 parent 3f72058 commit c9e69ee
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,9 @@ class GraphPartitioner {
// The final terminal node can already be fused to a OutEWiseFusable group.
auto fcond = [](OpPatternKind kind, bool is_sink) {
if (!is_sink) {
return kind <= kBroadcast;
// Elemwise, broadcast, and injective ops on the parallel branches
// are allowed be fused to the elemwise/broadcast master.
return kind <= kInjective;
} else {
return (kind <= kBroadcast ||
kind == kCommReduce ||
Expand Down
33 changes: 33 additions & 0 deletions tests/python/relay/test_pass_fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,38 @@ def expected(dim):
assert relay.ir_pass.alpha_equal(zz, after)


def test_fuse_parallel_injective():
"""Test fusing parallel injective ops to an elemwise op."""
def before():
x = relay.var("x", shape=(10, 20))
y = relay.add(x, relay.const(1, "float32"))
z = relay.squeeze(y)
u = relay.transpose(y, axes=[0, 1])
w = relay.left_shift(z, u)
return relay.Function([x], w)

def expected():
x = relay.var("p", shape=(10, 20))
y = relay.add(x, relay.const(1, "float32"))
z = relay.squeeze(y)
u = relay.transpose(y, axes=[0, 1])
w = relay.left_shift(z, u)
f1 = relay.Function([x], w)
x = relay.var("x", shape=(10, 20))
y = relay.Call(f1, [x])
return relay.Function([x], y)

z = before()
z = relay.ir_pass.infer_type(z)
zz = relay.ir_pass.fuse_ops(z, opt_level=0)
assert not relay.ir_pass.free_vars(zz)
zz = relay.ir_pass.fuse_ops(z, opt_level=2)
zz = relay.ir_pass.infer_type(zz)
assert not relay.ir_pass.free_vars(zz)
after = relay.ir_pass.infer_type(expected())
assert relay.ir_pass.alpha_equal(zz, after)


if __name__ == "__main__":
test_fuse_simple()
test_conv2d_fuse()
Expand All @@ -373,3 +405,4 @@ def expected(dim):
test_fuse_myia_regression()
test_fuse_tuple_get_elemwise()
test_tuple_get_root()
test_fuse_parallel_injective()

0 comments on commit c9e69ee

Please sign in to comment.