Skip to content

Commit

Permalink
[Relay, TOPI] Complete rewrite of where op to support broadcasting (a…
Browse files Browse the repository at this point in the history
…pache#6759)

* where type rel with broadcast

* add tests for where with broadcast

* clean up tests

* uncomment other tests

* add more tests

* update doc

* CHECK -> ICHECK

* add where any test

* fix format

* remove useless detections for one

* set manual seed

* ported shape broadcast helper func to hybridscript

* remove shape function helper from cpp

Co-authored-by: masa <[email protected]>
  • Loading branch information
2 people authored and Trevor Morris committed Oct 28, 2020
1 parent 8779bf4 commit 5b31545
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 111 deletions.
67 changes: 23 additions & 44 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
#include <unordered_set>
#include <vector>

#include "detail/broadcast.h"

namespace tvm {
namespace topi {

Expand Down Expand Up @@ -887,53 +889,30 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int axis, std::string
*/
inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
std::string name = "T_where", std::string tag = kBroadcast) {
ICHECK_EQ(x->shape.size(), y->shape.size())
<< "x and y must have the same shape.Got different number of dimension: " << x->shape.size()
<< " vs " << y->shape.size();
ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
<< y->dtype;
auto get_out_shape = [&]() {
auto bh1 = detail::BroadcastShape(x->shape, y->shape);
Array<PrimExpr> common_shape1(bh1.common_shape.begin(), bh1.common_shape.end());
auto bh2 = detail::BroadcastShape(condition->shape, common_shape1);
Array<PrimExpr> common_shape2(bh2.common_shape.begin(), bh2.common_shape.end());
return common_shape2;
};

if (x->shape.size() == 0) {
return compute(
condition->shape,
[&](const Array<Var>& indices) {
PrimExpr cond;
if (condition->shape.size() == 0) {
cond = condition();
} else {
Array<PrimExpr> condition_idx{indices[0]};
cond = condition(condition_idx);
}
return tvm::tir::Select(cond != 0, x(), y());
},
name, tag);
} else if (condition->shape.size() != 1) {
ICHECK_EQ(condition->shape.size(), x->shape.size())
<< "condition array must be either have the same shape as x or to be a "
"1-D array.Got different number of dimension: "
<< condition->shape.size() << " vs " << x->shape.size();
return compute(
x->shape,
[&](const Array<Var>& indices) {
return tvm::tir::Select(condition(indices) != 0, x(indices), y(indices));
},
name, tag);
} else {
int64_t cond_first_dim = topi::GetConstInt(condition->shape[0]);
int64_t x_first_dim = topi::GetConstInt(x->shape[0]);
if (cond_first_dim > 0 && x_first_dim > 0) {
ICHECK_EQ(cond_first_dim, x_first_dim)
<< "If condition is 1-D, the first dimension must be the same as x: " << cond_first_dim
<< " vs " << x_first_dim;
}
return compute(
x->shape,
[&](const Array<Var>& indices) {
Array<PrimExpr> condition_idx{indices[0]};
return tvm::tir::Select(condition(condition_idx) != 0, x(indices), y(indices));
},
name, tag);
}
auto oshape = get_out_shape();

auto c_bh = detail::BroadcastShape(condition->shape, oshape);
auto x_bh = detail::BroadcastShape(x->shape, oshape);
auto y_bh = detail::BroadcastShape(y->shape, oshape);

auto select = [&](tvm::Array<tvm::tir::Var> ovars) {
auto c = condition(InputIndexFromBroadcast(ovars, condition, c_bh.vars1, c_bh.all_vars));
auto true_val = x(InputIndexFromBroadcast(ovars, x, x_bh.vars1, x_bh.all_vars));
auto false_val = y(InputIndexFromBroadcast(ovars, y, y_bh.vars1, y_bh.all_vars));
return tvm::tir::Select(c != 0, true_val, false_val);
};

return compute(oshape, select, name, tag);
}

/*!
Expand Down
34 changes: 32 additions & 2 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,13 +810,43 @@ def stack_shape_func(attrs, inputs, _):
return [_stack_shape_func(inputs[0], convert(axis), convert(len(inputs)))]


@script
def _broadcast_shape_tensors(shape_tensor1, shape_tensor2):
rank1 = shape_tensor1.shape[0]
rank2 = shape_tensor2.shape[0]
out_rank = max(rank1, rank2)
bcast_shape_tensor = output_tensor((out_rank,), "int64")

for index in const_range(out_rank):
dim1 = int64(1)
dim2 = int64(1)

if rank1 == out_rank:
dim1 = shape_tensor1[index]
elif rank1 - (out_rank - index) >= 0:
dim1 = shape_tensor1[rank1 - (out_rank - index)]

if rank2 == out_rank:
dim2 = shape_tensor2[index]
elif rank2 - (out_rank - index) >= 0:
dim2 = shape_tensor2[rank2 - (out_rank - index)]

assert dim1 == dim2 or dim1 == 1 or dim2 == 1, "Invalid broadcast shapes"
bcast_shape_tensor[index] = max(dim1, dim2)

return bcast_shape_tensor


@_reg.register_shape_func("where", False)
def where_shape_func(attrs, inputs, _):
"""
Shape func for where.
"""
cond_shape = inputs[0]
x_shape = inputs[1]
out_shape = x_shape if x_shape.shape else cond_shape
y_shape = inputs[2]

bcast_shape = _broadcast_shape_tensors(x_shape, y_shape)
out_shape = _broadcast_shape_tensors(bcast_shape, cond_shape)

return [topi.math.identity(out_shape)]
return [out_shape]
17 changes: 9 additions & 8 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,25 +649,26 @@ def where(condition, x, y):
condition.
.. note::
The shape of condition, x, and y needs to be the same.
Shapes of condition, x, and y must be broadcastable to a common shape.
Semantics follow numpy where function
https://numpy.org/doc/stable/reference/generated/numpy.where.html
Parameters
----------
condition : relay.Expr
The condition array. The n-th element in `y` is selected when the n-th
value in the `condition` array is zero. Otherwise, the corresponding
element from `x` will be picked.
Where True, yield x, otherwise yield y
x : relay.Expr
The first array to be selected.
The first array or scalar to be selected.
y : relay.Expr
The second array to be selected.
The second array or scalar to be selected.
Returns
-------
result : relay.Expr
The selected array.
The selected array. The output shape is the broadcasted shape from
condition, x, and y.
Examples
--------
Expand All @@ -678,7 +679,7 @@ def where(condition, x, y):
condition = [[0, 1], [-1, 0]]
relay.where(conditon, x, y) = [[5, 2], [3, 8]]
condition = [1, 0]
condition = [[1], [0]]
relay.where(conditon, x, y) = [[1, 2], [7, 8]]
"""
return _make.where(condition, x, y)
Expand Down
47 changes: 14 additions & 33 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "../../transforms/pattern_utils.h"
#include "../make_op.h"
#include "../op_common.h"
#include "../type_relations.h"

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -1685,30 +1686,17 @@ bool WhereRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return false;
}

const auto& cond_shape = condition->shape;
const auto& x_shape = x->shape;
const auto& y_shape = y->shape;
ICHECK(x_shape.size() == y_shape.size()) << "x and y must have the same size";
ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
<< y->dtype;

if (cond_shape.size() != x_shape.size()) {
ICHECK_EQ(cond_shape.size(), 1) << "Shape of condition " << condition->shape
<< " must be either equal to x or has dimension of 1.";
}
for (size_t i = 0; i < x_shape.size(); i++) {
ICHECK(reporter->AssertEQ(x_shape[i], y_shape[i]))
<< "x and y must have the same shape: " << x_shape << " vs " << y_shape;
auto tensor_ty_condition = GetRef<TensorType>(condition);
auto tensor_ty_x = GetRef<TensorType>(x);
auto tensor_ty_y = GetRef<TensorType>(y);

if (i < cond_shape.size()) {
ICHECK(reporter->AssertEQ(cond_shape[i], x_shape[i]))
<< "condition and x must have the same shape: " << cond_shape << " vs " << x_shape;
}
}
if (x_shape.size() == 0) {
// if x and y are scalar, the condition shape becomes the output shape
reporter->Assign(types[3], TensorType(cond_shape, x->dtype));
} else {
reporter->Assign(types[3], TensorType(x_shape, x->dtype));
}
auto b_ty = ConcreteBroadcast(tensor_ty_x, tensor_ty_y, x->dtype);
auto ret_ty = ConcreteBroadcast(tensor_ty_condition, b_ty, b_ty->dtype);

reporter->Assign(types[3], ret_ty);
return true;
}

Expand All @@ -1731,17 +1719,10 @@ Return the elements, either from x or y, depending on the condition.
Given three ndarrays, condition, x, and y, return an ndarray with the elements
from x or y, depending on the elements from condition are true or false.
x and y must have the same shape. If condition has the same shape as x,
each element in the output array is from x if the corresponding element
in the condition is true, and from y if false.
If condition does not have the same shape as x, it must be a 1D array whose
size is the same as x’s first dimension size. Each row of the output array
is from x’s row if the corresponding element from condition is true, and
from y’s row if false.
When x and y are scalars, condition must be an 1D array. The output shape
is the same as condition's shape.
Shapes of condition, x, and y must be broadcastable to a common shape, which
is the output shape of this op. Semantics follow numpy where function.
https://numpy.org/doc/stable/reference/generated/numpy.where.html
Note that all non-zero values are interpreted as True in condition.
Expand All @@ -1753,7 +1734,7 @@ Examples::
where(cond, x, y) = [[5, 2], [3, 8]]
cond = [1, 0]
cond = [[1], [0]]
where(cond, x, y) = [[1, 2], [7, 8]]
cond = [0, 1]
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/type_relations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ bool EqualConstInt(const IndexExpr& lhs, int64_t value) {
return false;
}

Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) {
TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) {
std::vector<IndexExpr> oshape;
size_t ndim1 = t1->shape.size();
size_t ndim2 = t2->shape.size();
Expand Down
9 changes: 9 additions & 0 deletions src/relay/op/type_relations.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ bool IdentityRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
bool BroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter);

/*!
* \brief Determine the broadcasted shape from two input shapes
* \param t1 One of two Tensortype whose shapes are broadcasted
* \param t2 One of two Tensortype whose shapes are broadcasted
* \param output_dtype dtype of the output TensorType
* \return A TensorType whose shape is broadcasted from two input TensorType.
*/
TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype);

/*!
* \brief The broadcast type relation, implements the broadcasting
* rule over the two input types producing the broadcasted type.
Expand Down
2 changes: 2 additions & 0 deletions tests/python/frontend/pytorch/test_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,8 @@ def test_custom_lstm():
num_layers = 3
state_tensor_shape = (batch, hidden_size)

torch.manual_seed(1)

inp = torch.randn(seq_len, batch, input_size)

input_shapes = [
Expand Down
44 changes: 44 additions & 0 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,5 +1236,49 @@ def test_any_stack():
verify_any_stack(any_dims(4), (2, 1, 1, 4), 2, 2)


def verify_any_where(
cond_shape, x_shape, y_shape, cond_np_shape, x_np_shape, y_np_shape, y_np_shape_invalid=None
):
dtype = "float32"
cond = relay.var("cond", shape=cond_shape, dtype="bool")
x = relay.var("x", shape=x_shape, dtype=dtype)
y = relay.var("y", shape=y_shape, dtype=dtype)
z = relay.where(cond, x, y)
mod = tvm.IRModule()
mod["main"] = relay.Function([cond, x, y], z)

cond_np = np.random.randn(*cond_np_shape) > 0
x_np = np.random.randn(*x_np_shape).astype(dtype)
y_np = np.random.randn(*y_np_shape).astype(dtype)
expected = np.where(cond_np, x_np, y_np)

check_result([cond_np, x_np, y_np], mod, expected)

# verify invalid broadcasting check
if y_np_shape_invalid:
y_np_bad = np.random.randn(*y_np_shape_invalid).astype(dtype)
try:
check_result([cond_np, x_np, y_np_bad], mod, expected)
except tvm.error.TVMError as e:
error_msg = str(e).split("\n")[-1]
assert "Invalid broadcast shapes" in error_msg


@tvm.testing.uses_gpu
def test_any_where():
verify_any_where(any_dims(1), (5,), (5,), (5,), (5,), (5,))
verify_any_where(any_dims(1), any_dims(1), (5,), (5,), (5,), (5,))
verify_any_where(any_dims(1), any_dims(1), any_dims(1), (5,), (5,), (5,))
verify_any_where((5,), any_dims(1), any_dims(1), (5,), (5,), (5,))

# where with broadcast
verify_any_where(any_dims(1), any_dims(1), any_dims(1), (5,), (1,), (5,))
verify_any_where(any_dims(1), any_dims(2), any_dims(2), (5,), (5, 5), (5, 5))
verify_any_where(any_dims(1), any_dims(1), any_dims(2), (5,), (5,), (5, 5))
verify_any_where(
any_dims(2), any_dims(2), any_dims(2), (3, 4), (3, 1), (1, 4), y_np_shape_invalid=(2, 4)
)


if __name__ == "__main__":
pytest.main([__file__])
Loading

0 comments on commit 5b31545

Please sign in to comment.