Skip to content

Commit

Permalink
[Relay][Op]BroadcastToLike CollapseSumLike (apache#1886)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame authored and tqchen committed Oct 22, 2018
1 parent c51268c commit c2b3615
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 0 deletions.
18 changes: 18 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,17 @@ This level enables additional math and transform operators.
tvm.relay.image.resize


**Level 10: Temporary Operators**

This level support backpropagation of broadcast operators. It is temporary.

.. autosummary::
:nosignatures:

tvm.relay.broadcast_to_like
tvm.relay.collapse_sum_like


Level 1 Definitions
-------------------
.. autofunction:: tvm.relay.log
Expand Down Expand Up @@ -199,6 +210,13 @@ Level 4 Definitions
.. autofunction:: tvm.relay.prod



Level 5 Definitions
-------------------
.. autofunction:: tvm.relay.image.resize


Level 10 Definitions
--------------------
.. autofunction:: tvm.relay.broadcast_to_like
.. autofunction:: tvm.relay.collapse_sum_like
38 changes: 38 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,41 @@ def where(condition, x, y):
Note that the shape of condition, x, and y needs to be the same.
"""
return _make.where(condition, x, y)


def broadcast_to_like(data, broadcast_type):
"""Return an scalar value array with the same shape and type as the input array.
Parameters
----------
data : relay.Expr
The input tensor.
broadcast_type : relay.Expr
Provide the type to broadcast to.
Returns
-------
result : relay.Expr
The resulting tensor.
"""
return _make.broadcast_to_like(data, broadcast_type)


def collapse_sum_like(data, collapse_type):
"""Return an scalar value array with the same shape and type as the input array.
Parameters
----------
data : relay.Expr
The input tensor.
collapse_type : relay.Expr
Provide the type to collapse to.
Returns
-------
result : relay.Expr
The resulting tensor.
"""
return _make.collapse_sum_like(data, collapse_type)
61 changes: 61 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -718,5 +718,66 @@ RELAY_REGISTER_OP("squeeze")
.set_support_level(3)
.add_type_rel("Squeeze", SqueezeRel);

// Have no idea how to assert the constraint.
// CollapseSumLike: <A, B> -> B where BroadCast(A, B) = A
bool CollapseSumLikeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
reporter->Assign(types[2], types[1]);
return true;
}

Expr MakeCollapseSumLike(Expr data,
Expr collapse_type) {
static const Op& op = Op::Get("collapse_sum_like");
return CallNode::make(op, {data, collapse_type}, Attrs(), {});
}

TVM_REGISTER_API("relay.op._make.collapse_sum_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeCollapseSumLike, args, rv);
});

RELAY_REGISTER_OP("collapse_sum_like")
.describe(R"code(Collapse the first input to match the shape of the second input.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("collapse_type", "Tensor", "Provide the type to collapse to.")
.set_support_level(10)
.add_type_rel("CollapseSumLike", CollapseSumLikeRel);

// BroadCastToLike: <A, B> -> B where BroadCast(A, B) = B
bool BroadCastToLikeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
reporter->Assign(types[2], types[1]);
return true;
}

Expr MakeBroadCastToLike(Expr data,
Expr broadcast_type) {
static const Op& op = Op::Get("broadcast_to_like");
return CallNode::make(op, {data, broadcast_type}, Attrs(), {});
}

TVM_REGISTER_API("relay.op._make.broadcast_to_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeBroadCastToLike, args, rv);
});

RELAY_REGISTER_OP("broadcast_to_like")
.describe(R"code(Broadcast the first input to match the shape of the second input.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("broadcast_type", "Tensor", "Provide the type to broadcast to.")
.set_support_level(10)
.add_type_rel("BroadCastToLike", BroadCastToLikeRel);

} // namespace relay
} // namespace tvm
23 changes: 23 additions & 0 deletions tests/python/relay/test_op_level10.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
""" Support level10 operator test cases.
"""
import tvm
from tvm import relay

def test_collapse_sum_like():
x = relay.Var("x", relay.ty.TensorType((3, 4, 5, 6), "int8"))
y = relay.Var("y", relay.ty.TensorType((4, 1, 6), "int8"))
z = relay.collapse_sum_like(x, y)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType((4, 1, 6), "int8")


def test_broadcast_to_like():
x = relay.Var("x", relay.ty.TensorType((3, 4, 5, 6), "int8"))
y = relay.Var("y", relay.ty.TensorType((4, 1, 6), "int8"))
z = relay.broadcast_to_like(y, x)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType((3, 4, 5, 6), "int8")

if __name__ == "__main__":
test_collapse_sum_like()
test_broadcast_to_like()
1 change: 1 addition & 0 deletions tests/python/relay/test_pass_alpha_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,3 +461,4 @@ def test_op_alpha_equal():
test_let_alpha_equal()
test_if_alpha_equal()
test_op_alpha_equal()
test_var_alpha_equal()

0 comments on commit c2b3615

Please sign in to comment.