
[Relay, TOPI] Complete rewrite of where op to support broadcasting #6759

Merged · 13 commits · Oct 28, 2020
67 changes: 23 additions & 44 deletions include/tvm/topi/transform.h
@@ -39,6 +39,8 @@
#include <unordered_set>
#include <vector>

#include "detail/broadcast.h"

namespace tvm {
namespace topi {

@@ -887,53 +889,30 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int axis, std::string
*/
inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
std::string name = "T_where", std::string tag = kBroadcast) {
ICHECK_EQ(x->shape.size(), y->shape.size())
<< "x and y must have the same shape.Got different number of dimension: " << x->shape.size()
<< " vs " << y->shape.size();
ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
<< y->dtype;
auto get_out_shape = [&]() {
auto bh1 = detail::BroadcastShape(x->shape, y->shape);
Array<PrimExpr> common_shape1(bh1.common_shape.begin(), bh1.common_shape.end());
auto bh2 = detail::BroadcastShape(condition->shape, common_shape1);
Array<PrimExpr> common_shape2(bh2.common_shape.begin(), bh2.common_shape.end());
return common_shape2;
};

if (x->shape.size() == 0) {
return compute(
condition->shape,
[&](const Array<Var>& indices) {
PrimExpr cond;
if (condition->shape.size() == 0) {
cond = condition();
} else {
Array<PrimExpr> condition_idx{indices[0]};
cond = condition(condition_idx);
}
return tvm::tir::Select(cond != 0, x(), y());
},
name, tag);
} else if (condition->shape.size() != 1) {
ICHECK_EQ(condition->shape.size(), x->shape.size())
<< "condition array must be either have the same shape as x or to be a "
"1-D array.Got different number of dimension: "
<< condition->shape.size() << " vs " << x->shape.size();
return compute(
x->shape,
[&](const Array<Var>& indices) {
return tvm::tir::Select(condition(indices) != 0, x(indices), y(indices));
},
name, tag);
} else {
int64_t cond_first_dim = topi::GetConstInt(condition->shape[0]);
int64_t x_first_dim = topi::GetConstInt(x->shape[0]);
if (cond_first_dim > 0 && x_first_dim > 0) {
ICHECK_EQ(cond_first_dim, x_first_dim)
<< "If condition is 1-D, the first dimension must be the same as x: " << cond_first_dim
<< " vs " << x_first_dim;
}
return compute(
x->shape,
[&](const Array<Var>& indices) {
Array<PrimExpr> condition_idx{indices[0]};
return tvm::tir::Select(condition(condition_idx) != 0, x(indices), y(indices));
},
name, tag);
}
auto oshape = get_out_shape();

auto c_bh = detail::BroadcastShape(condition->shape, oshape);
auto x_bh = detail::BroadcastShape(x->shape, oshape);
auto y_bh = detail::BroadcastShape(y->shape, oshape);

auto select = [&](tvm::Array<tvm::tir::Var> ovars) {
auto c = condition(InputIndexFromBroadcast(ovars, condition, c_bh.vars1, c_bh.all_vars));
auto true_val = x(InputIndexFromBroadcast(ovars, x, x_bh.vars1, x_bh.all_vars));
auto false_val = y(InputIndexFromBroadcast(ovars, y, y_bh.vars1, y_bh.all_vars));
return tvm::tir::Select(c != 0, true_val, false_val);
};

return compute(oshape, select, name, tag);
}
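
The rewritten kernel computes a single output shape by broadcasting x against y and then against condition, and remaps each output index back into every input via InputIndexFromBroadcast. A minimal NumPy sketch of the target semantics (illustrative only; numpy.where already behaves this way):

import numpy as np

cond = np.array([[True], [False]])  # shape (2, 1)
x = np.array([[1.0, 2.0]])          # shape (1, 2)
y = np.zeros((2, 2))                # shape (2, 2)

# All three operands broadcast to the common shape (2, 2) before the
# elementwise select, matching Select(c != 0, true_val, false_val) above.
out = np.where(cond, x, y)
# [[1., 2.],
#  [0., 0.]]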

/*!
34 changes: 32 additions & 2 deletions python/tvm/relay/op/_transform.py
@@ -810,13 +810,43 @@ def stack_shape_func(attrs, inputs, _):
return [_stack_shape_func(inputs[0], convert(axis), convert(len(inputs)))]


@script
def _broadcast_shape_tensors(shape_tensor1, shape_tensor2):
rank1 = shape_tensor1.shape[0]
rank2 = shape_tensor2.shape[0]
out_rank = max(rank1, rank2)
bcast_shape_tensor = output_tensor((out_rank,), "int64")

for index in const_range(out_rank):
dim1 = int64(1)
dim2 = int64(1)

if rank1 == out_rank:
dim1 = shape_tensor1[index]
elif rank1 - (out_rank - index) >= 0:
dim1 = shape_tensor1[rank1 - (out_rank - index)]

if rank2 == out_rank:
dim2 = shape_tensor2[index]
elif rank2 - (out_rank - index) >= 0:
dim2 = shape_tensor2[rank2 - (out_rank - index)]

assert dim1 == dim2 or dim1 == 1 or dim2 == 1, "Invalid broadcast shapes"
bcast_shape_tensor[index] = max(dim1, dim2)

return bcast_shape_tensor
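
The hybrid script above follows NumPy's broadcasting rule: shapes are right-aligned, and each pair of dimensions must be equal or contain a 1. A plain-Python sketch of the same rule (hypothetical helper, for illustration only):

def broadcast_shapes(shape1, shape2):
    # Right-align the two shapes; a missing leading dimension behaves as 1.
    rank1, rank2 = len(shape1), len(shape2)
    out_rank = max(rank1, rank2)
    out = []
    for index in range(out_rank):
        i1 = rank1 - (out_rank - index)
        i2 = rank2 - (out_rank - index)
        dim1 = shape1[i1] if i1 >= 0 else 1
        dim2 = shape2[i2] if i2 >= 0 else 1
        assert dim1 == dim2 or dim1 == 1 or dim2 == 1, "Invalid broadcast shapes"
        out.append(max(dim1, dim2))
    return tuple(out)

assert broadcast_shapes((3, 1), (1, 4)) == (3, 4)
assert broadcast_shapes((5,), (5, 5)) == (5, 5)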


@_reg.register_shape_func("where", False)
def where_shape_func(attrs, inputs, _):
"""
Shape func for where.
"""
cond_shape = inputs[0]
x_shape = inputs[1]
out_shape = x_shape if x_shape.shape else cond_shape
y_shape = inputs[2]

bcast_shape = _broadcast_shape_tensors(x_shape, y_shape)
out_shape = _broadcast_shape_tensors(bcast_shape, cond_shape)

return [topi.math.identity(out_shape)]
return [out_shape]
17 changes: 9 additions & 8 deletions python/tvm/relay/op/transform.py
@@ -649,25 +649,26 @@ def where(condition, x, y):
condition.

.. note::
The shape of condition, x, and y needs to be the same.
Shapes of condition, x, and y must be broadcastable to a common shape.
Semantics follow numpy where function
https://numpy.org/doc/stable/reference/generated/numpy.where.html

Parameters
----------
condition : relay.Expr
The condition array. The n-th element in `y` is selected when the n-th
value in the `condition` array is zero. Otherwise, the corresponding
element from `x` will be picked.
Where True, yield x, otherwise yield y

x : relay.Expr
The first array to be selected.
The first array or scalar to be selected.

y : relay.Expr
The second array to be selected.
The second array or scalar to be selected.

Returns
-------
result : relay.Expr
The selected array.
The selected array. The output shape is the broadcasted shape from
condition, x, and y.

Examples
--------
@@ -678,7 +679,7 @@ def where(condition, x, y):
condition = [[0, 1], [-1, 0]]
relay.where(condition, x, y) = [[5, 2], [3, 8]]

condition = [1, 0]
condition = [[1], [0]]
relay.where(condition, x, y) = [[1, 2], [7, 8]]
"""
return _make.where(condition, x, y)
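
A minimal end-to-end sketch of the new broadcasting behavior at the Relay level (shapes mirror the docstring example; InferType is used only to surface the inferred output type):

import tvm
from tvm import relay

cond = relay.var("cond", shape=(2, 1), dtype="bool")
x = relay.var("x", shape=(1, 2), dtype="float32")
y = relay.var("y", shape=(2, 2), dtype="float32")

mod = tvm.IRModule.from_expr(relay.Function([cond, x, y], relay.where(cond, x, y)))
mod = relay.transform.InferType()(mod)
print(mod)  # output type is inferred as Tensor[(2, 2), float32]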
47 changes: 14 additions & 33 deletions src/relay/op/tensor/transform.cc
@@ -45,6 +45,7 @@
#include "../../transforms/pattern_utils.h"
#include "../make_op.h"
#include "../op_common.h"
#include "../type_relations.h"

namespace tvm {
namespace relay {
@@ -1685,30 +1686,17 @@ bool WhereRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return false;
}

const auto& cond_shape = condition->shape;
const auto& x_shape = x->shape;
const auto& y_shape = y->shape;
ICHECK(x_shape.size() == y_shape.size()) << "x and y must have the same size";
ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
<< y->dtype;

if (cond_shape.size() != x_shape.size()) {
ICHECK_EQ(cond_shape.size(), 1) << "Shape of condition " << condition->shape
<< " must be either equal to x or has dimension of 1.";
}
for (size_t i = 0; i < x_shape.size(); i++) {
ICHECK(reporter->AssertEQ(x_shape[i], y_shape[i]))
<< "x and y must have the same shape: " << x_shape << " vs " << y_shape;
auto tensor_ty_condition = GetRef<TensorType>(condition);
auto tensor_ty_x = GetRef<TensorType>(x);
auto tensor_ty_y = GetRef<TensorType>(y);

if (i < cond_shape.size()) {
ICHECK(reporter->AssertEQ(cond_shape[i], x_shape[i]))
<< "condition and x must have the same shape: " << cond_shape << " vs " << x_shape;
}
}
if (x_shape.size() == 0) {
// if x and y are scalar, the condition shape becomes the output shape
reporter->Assign(types[3], TensorType(cond_shape, x->dtype));
} else {
reporter->Assign(types[3], TensorType(x_shape, x->dtype));
}
auto b_ty = ConcreteBroadcast(tensor_ty_x, tensor_ty_y, x->dtype);
auto ret_ty = ConcreteBroadcast(tensor_ty_condition, b_ty, b_ty->dtype);

reporter->Assign(types[3], ret_ty);
return true;
}
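
WhereRel now composes two broadcasts: x against y first, then condition against that intermediate. In NumPy terms (sketch; np.broadcast_shapes requires NumPy >= 1.20):

import numpy as np

# x (1, 4) against y (3, 1), then condition (3, 4) against the intermediate,
# mirroring the two ConcreteBroadcast calls above.
xy = np.broadcast_shapes((1, 4), (3, 1))  # -> (3, 4)
out = np.broadcast_shapes((3, 4), xy)
assert out == (3, 4)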

@@ -1731,17 +1719,10 @@ Return the elements, either from x or y, depending on the condition.

Given three ndarrays, condition, x, and y, return an ndarray with the elements
from x or y, depending on whether the corresponding elements in condition are true or false.
x and y must have the same shape. If condition has the same shape as x,
each element in the output array is from x if the corresponding element
in the condition is true, and from y if false.

If condition does not have the same shape as x, it must be a 1D array whose
size is the same as x’s first dimension size. Each row of the output array
is from x’s row if the corresponding element from condition is true, and
from y’s row if false.

When x and y are scalars, condition must be an 1D array. The output shape
is the same as condition's shape.
Shapes of condition, x, and y must be broadcastable to a common shape, which
is the output shape of this op. Semantics follow numpy where function.
https://numpy.org/doc/stable/reference/generated/numpy.where.html

Note that all non-zero values are interpreted as True in condition.

Expand All @@ -1753,7 +1734,7 @@ Examples::
where(cond, x, y) = [[5, 2], [3, 8]]


cond = [1, 0]
cond = [[1], [0]]
where(cond, x, y) = [[1, 2], [7, 8]]

cond = [0, 1]
2 changes: 1 addition & 1 deletion src/relay/op/type_relations.cc
@@ -64,7 +64,7 @@ bool EqualConstInt(const IndexExpr& lhs, int64_t value) {
return false;
}

Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) {
TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) {
std::vector<IndexExpr> oshape;
size_t ndim1 = t1->shape.size();
size_t ndim2 = t2->shape.size();
9 changes: 9 additions & 0 deletions src/relay/op/type_relations.h
@@ -57,6 +57,15 @@ bool IdentityRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
bool BroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter);

/*!
* \brief Determine the broadcast output shape from two input shapes
* \param t1 One of the two TensorTypes whose shapes are broadcast
* \param t2 One of the two TensorTypes whose shapes are broadcast
* \param output_dtype dtype of the output TensorType
* \return A TensorType whose shape is the broadcast of the two input TensorTypes.
*/
TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype);

/*!
* \brief The broadcast type relation, implements the broadcasting
* rule over the two input types producing the broadcasted type.
2 changes: 2 additions & 0 deletions tests/python/frontend/pytorch/test_lstm.py
@@ -277,6 +277,8 @@ def test_custom_lstm():
num_layers = 3
state_tensor_shape = (batch, hidden_size)

torch.manual_seed(1)

inp = torch.randn(seq_len, batch, input_size)

input_shapes = [
44 changes: 44 additions & 0 deletions tests/python/relay/test_any.py
@@ -1236,5 +1236,49 @@ def test_any_stack():
verify_any_stack(any_dims(4), (2, 1, 1, 4), 2, 2)


def verify_any_where(
cond_shape, x_shape, y_shape, cond_np_shape, x_np_shape, y_np_shape, y_np_shape_invalid=None
):
dtype = "float32"
cond = relay.var("cond", shape=cond_shape, dtype="bool")
x = relay.var("x", shape=x_shape, dtype=dtype)
y = relay.var("y", shape=y_shape, dtype=dtype)
z = relay.where(cond, x, y)
mod = tvm.IRModule()
mod["main"] = relay.Function([cond, x, y], z)

cond_np = np.random.randn(*cond_np_shape) > 0
x_np = np.random.randn(*x_np_shape).astype(dtype)
y_np = np.random.randn(*y_np_shape).astype(dtype)
expected = np.where(cond_np, x_np, y_np)

check_result([cond_np, x_np, y_np], mod, expected)

# verify that invalid broadcast shapes are rejected
if y_np_shape_invalid:
y_np_bad = np.random.randn(*y_np_shape_invalid).astype(dtype)
try:
check_result([cond_np, x_np, y_np_bad], mod, expected)
except tvm.error.TVMError as e:
error_msg = str(e).split("\n")[-1]
assert "Invalid broadcast shapes" in error_msg


@tvm.testing.uses_gpu
def test_any_where():
verify_any_where(any_dims(1), (5,), (5,), (5,), (5,), (5,))
verify_any_where(any_dims(1), any_dims(1), (5,), (5,), (5,), (5,))
verify_any_where(any_dims(1), any_dims(1), any_dims(1), (5,), (5,), (5,))
verify_any_where((5,), any_dims(1), any_dims(1), (5,), (5,), (5,))

# where with broadcast
verify_any_where(any_dims(1), any_dims(1), any_dims(1), (5,), (1,), (5,))
verify_any_where(any_dims(1), any_dims(2), any_dims(2), (5,), (5, 5), (5, 5))
verify_any_where(any_dims(1), any_dims(1), any_dims(2), (5,), (5,), (5, 5))
verify_any_where(
any_dims(2), any_dims(2), any_dims(2), (3, 4), (3, 1), (1, 4), y_np_shape_invalid=(2, 4)
)


if __name__ == "__main__":
pytest.main([__file__])