Skip to content

Commit

Permalink
Check function attr for alpha equal (#4479)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored and tqchen committed Dec 8, 2019
1 parent 03a59bc commit 5fe5cee
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/relay/ir/hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ class RelayHashHandler:
hash = Combine(hash, TypeHash(func->ret_type));
hash = Combine(hash, ExprHash(func->body));

hash = Combine(hash, AttrHash(func->attrs));

return hash;
}

Expand Down
24 changes: 24 additions & 0 deletions tests/python/relay/test_pass_alpha_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,29 @@ def test_tuple_get_item_alpha_equal():
assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))


def test_multi_node_subgraph():
x0 = relay.var('x0', shape=(10, 10))
w00 = relay.var('w00', shape=(10, 10))
w01 = relay.var('w01', shape=(10, 10))
w02 = relay.var('w02', shape=(10, 10))
z00 = relay.add(x0, w00)
p00 = relay.subtract(z00, w01)
q00 = relay.multiply(p00, w02)
func0 = relay.Function([x0, w00, w01, w02], q00)
func0 = func0.set_attribute("FuncName", tvm.expr.StringImm("a"))

x1 = relay.var('x1', shape=(10, 10))
w10 = relay.var('w10', shape=(10, 10))
w11 = relay.var('w11', shape=(10, 10))
w12 = relay.var('w12', shape=(10, 10))
z10 = relay.add(x1, w10)
p10 = relay.subtract(z10, w11)
q10 = relay.multiply(p10, w12)
func1 = relay.Function([x1, w10, w11, w12], q10)
func1 = func1.set_attribute("FuncName", tvm.expr.StringImm("b"))
assert not alpha_equal(func0, func1)


def test_function_alpha_equal():
tt1 = relay.TensorType((1, 2, 3), "float32")
tt2 = relay.TensorType((4, 5, 6), "int8")
Expand Down Expand Up @@ -639,6 +662,7 @@ def test_tuple_match():
test_tuple_alpha_equal()
test_tuple_get_item_alpha_equal()
test_function_alpha_equal()
test_function_attr()
test_call_alpha_equal()
test_let_alpha_equal()
test_if_alpha_equal()
Expand Down

0 comments on commit 5fe5cee

Please sign in to comment.