From c2b36154778503a509a70a3b5309b201969eccab Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?=
Date: Mon, 22 Oct 2018 09:31:59 -0700
Subject: [PATCH] [Relay][Op]BroadcastToLike CollapseSumLike (#1886)

---
 docs/langref/relay_op.rst                   | 18 ++++++
 python/tvm/relay/op/transform.py            | 38 +++++++++++++
 src/relay/op/tensor/transform.cc            | 61 +++++++++++++++++++++
 tests/python/relay/test_op_level10.py       | 23 ++++++++
 tests/python/relay/test_pass_alpha_equal.py |  1 +
 5 files changed, 141 insertions(+)
 create mode 100644 tests/python/relay/test_op_level10.py

diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index a36f8e6c71cf..6eba6b25d9fd 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -123,6 +123,17 @@ This level enables additional math and transform operators.
    tvm.relay.image.resize
 
 
+**Level 10: Temporary Operators**
+
+This level supports backpropagation of broadcast operators. It is temporary.
+
+.. autosummary::
+   :nosignatures:
+
+   tvm.relay.broadcast_to_like
+   tvm.relay.collapse_sum_like
+
+
 Level 1 Definitions
 -------------------
 .. autofunction:: tvm.relay.log
@@ -199,6 +210,13 @@ Level 4 Definitions
 .. autofunction:: tvm.relay.prod
 
 
+
 Level 5 Definitions
 -------------------
 .. autofunction:: tvm.relay.image.resize
+
+
+Level 10 Definitions
+--------------------
+.. autofunction:: tvm.relay.broadcast_to_like
+.. autofunction:: tvm.relay.collapse_sum_like
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index c2036f509133..84e2398f0a9e 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -242,3 +242,41 @@ def where(condition, x, y):
     Note that the shape of condition, x, and y needs to be the same.
     """
     return _make.where(condition, x, y)
+
+
+def broadcast_to_like(data, broadcast_type):
+    """Broadcast the input tensor to match the shape of the second input.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input tensor.
+
+    broadcast_type : relay.Expr
+        Provide the type to broadcast to.
+
+    Returns
+    -------
+    result : relay.Expr
+        The resulting tensor.
+    """
+    return _make.broadcast_to_like(data, broadcast_type)
+
+
+def collapse_sum_like(data, collapse_type):
+    """Sum the input tensor over broadcast axes so its shape matches the second input.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input tensor.
+
+    collapse_type : relay.Expr
+        Provide the type to collapse to.
+
+    Returns
+    -------
+    result : relay.Expr
+        The resulting tensor.
+    """
+    return _make.collapse_sum_like(data, collapse_type)
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 61ee2778d0a2..e3c8bcef217e 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -718,5 +718,66 @@ RELAY_REGISTER_OP("squeeze")
 .set_support_level(3)
 .add_type_rel("Squeeze", SqueezeRel);
 
+// NOTE: the constraint below (BroadCast(A, B) = A) is not checked by the type relation.
+// CollapseSumLike: <A, B> -> B where BroadCast(A, B) = A
+bool CollapseSumLikeRel(const Array<Type>& types,
+                        int num_inputs,
+                        const Attrs& attrs,
+                        const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  reporter->Assign(types[2], types[1]);
+  return true;
+}
+
+Expr MakeCollapseSumLike(Expr data,
+                         Expr collapse_type) {
+  static const Op& op = Op::Get("collapse_sum_like");
+  return CallNode::make(op, {data, collapse_type}, Attrs(), {});
+}
+
+TVM_REGISTER_API("relay.op._make.collapse_sum_like")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 2>(MakeCollapseSumLike, args, rv);
+  });
+
+RELAY_REGISTER_OP("collapse_sum_like")
+.describe(R"code(Collapse the first input, by summation, to match the shape of the second input.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(2)
+.add_argument("data", "Tensor", "The input tensor.")
+.add_argument("collapse_type", "Tensor", "Provide the type to collapse to.")
+.set_support_level(10)
+.add_type_rel("CollapseSumLike", CollapseSumLikeRel);
+
+// BroadCastToLike: <A, B> -> B where BroadCast(A, B) = B
+bool BroadCastToLikeRel(const Array<Type>& types,
+                        int num_inputs,
+                        const Attrs& attrs,
+                        const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  reporter->Assign(types[2], types[1]);
+  return true;
+}
+
+Expr MakeBroadCastToLike(Expr data,
+                         Expr broadcast_type) {
+  static const Op& op = Op::Get("broadcast_to_like");
+  return CallNode::make(op, {data, broadcast_type}, Attrs(), {});
+}
+
+TVM_REGISTER_API("relay.op._make.broadcast_to_like")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 2>(MakeBroadCastToLike, args, rv);
+  });
+
+RELAY_REGISTER_OP("broadcast_to_like")
+.describe(R"code(Broadcast the first input to match the shape of the second input.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(2)
+.add_argument("data", "Tensor", "The input tensor.")
+.add_argument("broadcast_type", "Tensor", "Provide the type to broadcast to.")
+.set_support_level(10)
+.add_type_rel("BroadCastToLike", BroadCastToLikeRel);
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py
new file mode 100644
index 000000000000..9486d029876d
--- /dev/null
+++ b/tests/python/relay/test_op_level10.py
@@ -0,0 +1,23 @@
+"""Support level10 operator test cases.
+"""
+import tvm
+from tvm import relay
+
+def test_collapse_sum_like():
+    x = relay.Var("x", relay.ty.TensorType((3, 4, 5, 6), "int8"))
+    y = relay.Var("y", relay.ty.TensorType((4, 1, 6), "int8"))
+    z = relay.collapse_sum_like(x, y)
+    zz = relay.ir_pass.infer_type(z)
+    assert zz.checked_type == relay.ty.TensorType((4, 1, 6), "int8")
+
+
+def test_broadcast_to_like():
+    x = relay.Var("x", relay.ty.TensorType((3, 4, 5, 6), "int8"))
+    y = relay.Var("y", relay.ty.TensorType((4, 1, 6), "int8"))
+    z = relay.broadcast_to_like(y, x)
+    zz = relay.ir_pass.infer_type(z)
+    assert zz.checked_type == relay.ty.TensorType((3, 4, 5, 6), "int8")
+
+if __name__ == "__main__":
+    test_collapse_sum_like()
+    test_broadcast_to_like()
diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py
index 7b27cb7ee2d4..de4df7c84b9f 100644
--- a/tests/python/relay/test_pass_alpha_equal.py
+++ b/tests/python/relay/test_pass_alpha_equal.py
@@ -461,3 +461,4 @@ def test_op_alpha_equal():
     test_let_alpha_equal()
     test_if_alpha_equal()
     test_op_alpha_equal()
+    test_var_alpha_equal()
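
Usage sketch (not part of the patch): a minimal example of how the two new
operators are expected to pair up around an implicit broadcast, e.g. when a
gradient has to be propagated back to a broadcast operand. It only uses APIs
already exercised by the tests above; the variable names and shapes are
illustrative, not taken from the PR.

    import tvm
    from tvm import relay

    # Forward: y of shape (4, 1, 6) is implicitly broadcast to x's shape (3, 4, 5, 6)
    # by the elementwise add.
    x = relay.Var("x", relay.ty.TensorType((3, 4, 5, 6), "float32"))
    y = relay.Var("y", relay.ty.TensorType((4, 1, 6), "float32"))
    out = relay.add(x, y)

    # Backward, written by hand: the incoming gradient has x's shape, so the
    # gradient w.r.t. y is recovered by summing over the broadcast axes.
    grad_out = relay.Var("grad_out", relay.ty.TensorType((3, 4, 5, 6), "float32"))
    grad_y = relay.collapse_sum_like(grad_out, y)

    # broadcast_to_like goes the other way: it expands a tensor to the shape of
    # another one, here replaying the forward broadcast explicitly.
    y_full = relay.broadcast_to_like(y, x)

    assert relay.ir_pass.infer_type(grad_y).checked_type == relay.ty.TensorType((4, 1, 6), "float32")
    assert relay.ir_pass.infer_type(y_full).checked_type == relay.ty.TensorType((3, 4, 5, 6), "float32")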