[Relay, TOPI] Complete rewrite of where op to support broadcasting #6759

Merged, 13 commits, Oct 28, 2020
36 changes: 0 additions & 36 deletions include/tvm/topi/broadcast.h
@@ -69,42 +69,6 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t,
  return tvm::te::compute(oshape, l, name, tag);
}

// This is used in the shape func of where op
inline tvm::te::Tensor broadcast_shape_tensors(const tvm::te::Tensor& shape_tensor1,
                                               const tvm::te::Tensor& shape_tensor2,
                                               std::string name = "T_broadcast_shape_tensors",
                                               std::string tag = kBroadcast) {
  const auto rank1 = detail::GetConstInt(shape_tensor1->shape[0]);
  const auto rank2 = detail::GetConstInt(shape_tensor2->shape[0]);
  const auto out_rank = std::max<int32_t>(rank1, rank2);
  const tvm::PrimExpr one = tvm::cast(shape_tensor1->dtype, PrimExpr(1));

  auto select_dim = [&](const tvm::te::Tensor& shape_tensor, int rank,
                        tvm::tir::Var index) -> PrimExpr {
    if (rank < out_rank) {
      // if the rank is smaller, dimension 1 is prepended according to
      // the numpy broadcasting semantics.
      return tvm::tir::Select(rank - (out_rank - index) < 0, one,
                              shape_tensor[rank - (out_rank - index)]);
    } else {
      // rank == out_rank, safe to index directly
      return shape_tensor[index];
    }
  };

  auto func = [&](tvm::Array<tvm::tir::Var> ovars) {
    auto index = ovars[0];
    PrimExpr dim1 = select_dim(shape_tensor1, rank1, index);
    PrimExpr dim2 = select_dim(shape_tensor2, rank2, index);
    return tvm::max(dim1, dim2);
  };

  Array<PrimExpr> oshape;
  oshape.push_back(PrimExpr(out_rank));

  return tvm::te::compute(oshape, func, name, tag);
}

#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \
  inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; } \
  inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B, \
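For context, the broadcast_shape_tensors helper removed above mirrors numpy's broadcasting rule: ranks are aligned by prepending 1s to the shorter shape, then each output dimension is the elementwise maximum of the pair. A quick ground-truth check in plain numpy (illustration only, not part of this PR):

    import numpy as np

    # Ranks align by prepending 1s; each output dim is the max of the pair.
    assert np.broadcast(np.empty((3, 1)), np.empty((1, 4))).shape == (3, 4)
    assert np.broadcast(np.empty((5,)), np.empty((5, 5))).shape == (5, 5)
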
32 changes: 30 additions & 2 deletions python/tvm/relay/op/_transform.py
@@ -810,6 +810,33 @@ def stack_shape_func(attrs, inputs, _):
    return [_stack_shape_func(inputs[0], convert(axis), convert(len(inputs)))]


@script
def _broadcast_shape_tensors(shape_tensor1, shape_tensor2):
    rank1 = shape_tensor1.shape[0]
    rank2 = shape_tensor2.shape[0]
    out_rank = max(rank1, rank2)
    bcast_shape_tensor = output_tensor((out_rank,), "int64")

    for index in const_range(out_rank):
        dim1 = int64(1)
        dim2 = int64(1)

        if rank1 == out_rank:
            dim1 = shape_tensor1[index]
        elif rank1 - (out_rank - index) >= 0:
            dim1 = shape_tensor1[rank1 - (out_rank - index)]

        if rank2 == out_rank:
            dim2 = shape_tensor2[index]
        elif rank2 - (out_rank - index) >= 0:
            dim2 = shape_tensor2[rank2 - (out_rank - index)]

        assert dim1 == dim2 or dim1 == 1 or dim2 == 1, "Invalid broadcast shapes"
        bcast_shape_tensor[index] = max(dim1, dim2)

    return bcast_shape_tensor
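
To see what the hybrid script above computes, here is a plain-Python mirror of the same loop (an illustration only; broadcast_shapes is a hypothetical helper name, not TVM code):

    def broadcast_shapes(s1, s2):
        # Align ranks by treating missing leading dims as 1, then take the max.
        out_rank = max(len(s1), len(s2))
        out = []
        for index in range(out_rank):
            dim1 = s1[len(s1) - (out_rank - index)] if len(s1) - (out_rank - index) >= 0 else 1
            dim2 = s2[len(s2) - (out_rank - index)] if len(s2) - (out_rank - index) >= 0 else 1
            assert dim1 == dim2 or dim1 == 1 or dim2 == 1, "Invalid broadcast shapes"
            out.append(max(dim1, dim2))
        return out

    assert broadcast_shapes([3, 1], [1, 4]) == [3, 4]
    assert broadcast_shapes([5], [5, 5]) == [5, 5]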


@_reg.register_shape_func("where", False)
def where_shape_func(attrs, inputs, _):
    """
@@ -818,7 +845,8 @@ def where_shape_func(attrs, inputs, _):
    cond_shape = inputs[0]
    x_shape = inputs[1]
    y_shape = inputs[2]
    bcast_shape = topi.broadcast.broadcast_shape_tensors(x_shape, y_shape)
    out_shape = topi.broadcast.broadcast_shape_tensors(bcast_shape, cond_shape)

    bcast_shape = _broadcast_shape_tensors(x_shape, y_shape)
    out_shape = _broadcast_shape_tensors(bcast_shape, cond_shape)

    return [out_shape]
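
The shape function only runs when operand shapes are unknown at compile time; a minimal usage sketch (assuming the Relay dynamic-shape API exercised by the tests below; the two-dimensional shapes are arbitrary):

    import tvm
    from tvm import relay

    # Every dim is relay.Any(), so the output shape of where must be
    # computed at runtime by where_shape_func above.
    cond = relay.var("cond", shape=(relay.Any(), relay.Any()), dtype="bool")
    x = relay.var("x", shape=(relay.Any(), relay.Any()), dtype="float32")
    y = relay.var("y", shape=(relay.Any(), relay.Any()), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([cond, x, y], relay.where(cond, x, y)))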
22 changes: 1 addition & 21 deletions python/tvm/topi/broadcast.py
@@ -22,7 +22,7 @@
def broadcast_to(data, shape):
    """Broadcast the src to the target shape

    We follow the numpy broadcasting rule.
    We follows the numpy broadcasting rule.
    See also https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html

    Parameters
@@ -40,26 +40,6 @@ def broadcast_to(data, shape):
    return _cpp.broadcast_to(data, shape)


def broadcast_shape_tensors(shape_tensor1, shape_tensor2):
    """Compute a shape tensor whose values represent the broadcasted shape
    of two input shape tensors

    Parameters
    ----------
    shape_tensor1 : tvm.te.Tensor
        One of the input shape tensors

    shape_tensor2 : tvm.te.Tensor
        One of the input shape tensors

    Returns
    -------
    ret : tvm.te.Tensor
        A shape tensor whose values represent the broadcasted shape
    """
    return _cpp.broadcast_shape_tensors(shape_tensor1, shape_tensor2)


def add(lhs, rhs):
    """Addition with auto-broadcasting

Expand Down
4 changes: 0 additions & 4 deletions src/topi/broadcast.cc
@@ -76,9 +76,5 @@ TVM_REGISTER_GLOBAL("topi.broadcast_to").set_body([](TVMArgs args, TVMRetValue*
  *rv = broadcast_to(args[0], args[1]);
});

TVM_REGISTER_GLOBAL("topi.broadcast_shape_tensors").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = broadcast_shape_tensors(args[0], args[1]);
});

} // namespace topi
} // namespace tvm
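
The deleted registration was the FFI endpoint behind the Python wrapper removed from python/tvm/topi/broadcast.py above; globals registered this way are looked up by name from Python. A sketch using the surviving broadcast_to global for illustration:

    import tvm

    # The removed "topi.broadcast_shape_tensors" global was resolved the
    # same way before this PR deleted it.
    f = tvm.get_global_func("topi.broadcast_to")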
17 changes: 15 additions & 2 deletions tests/python/relay/test_any.py
@@ -1236,7 +1236,9 @@ def test_any_stack():
    verify_any_stack(any_dims(4), (2, 1, 1, 4), 2, 2)


def verify_any_where(cond_shape, x_shape, y_shape, cond_np_shape, x_np_shape, y_np_shape):
def verify_any_where(
    cond_shape, x_shape, y_shape, cond_np_shape, x_np_shape, y_np_shape, y_np_shape_invalid=None
):
    dtype = "float32"
    cond = relay.var("cond", shape=cond_shape, dtype="bool")
    x = relay.var("x", shape=x_shape, dtype=dtype)
@@ -1252,6 +1254,15 @@ def verify_any_where(cond_shape, x_shape, y_shape, cond_np_shape, x_np_shape, y_

    check_result([cond_np, x_np, y_np], mod, expected)

    # verify invalid broadcasting check
    if y_np_shape_invalid:
        y_np_bad = np.random.randn(*y_np_shape_invalid).astype(dtype)
        try:
            check_result([cond_np, x_np, y_np_bad], mod, expected)
        except tvm.error.TVMError as e:
            error_msg = str(e).split("\n")[-1]
            assert "Invalid broadcast shapes" in error_msg


@tvm.testing.uses_gpu
def test_any_where():
@@ -1264,7 +1275,9 @@ def test_any_where():
    verify_any_where(any_dims(1), any_dims(1), any_dims(1), (5,), (1,), (5,))
    verify_any_where(any_dims(1), any_dims(2), any_dims(2), (5,), (5, 5), (5, 5))
    verify_any_where(any_dims(1), any_dims(1), any_dims(2), (5,), (5,), (5, 5))
    verify_any_where(any_dims(2), any_dims(2), any_dims(2), (3, 4), (3, 1), (1, 4))
    verify_any_where(
        any_dims(2), any_dims(2), any_dims(2), (3, 4), (3, 1), (1, 4), y_np_shape_invalid=(2, 4)
    )


if __name__ == "__main__":