Commit 5423c70

add where any test

masahi committed Oct 26, 2020
1 parent 1e9ca8b
Showing 5 changed files with 100 additions and 3 deletions.
40 changes: 40 additions & 0 deletions include/tvm/topi/broadcast.h
@@ -69,6 +69,46 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t,
return tvm::te::compute(oshape, l, name, tag);
}

inline tvm::te::Tensor broadcast_shape_tensors(const tvm::te::Tensor& shape_tensor1,
const tvm::te::Tensor& shape_tensor2,
std::string name = "T_broadcast_shape_tensors",
std::string tag = kBroadcast) {
const auto rank1 = detail::GetConstInt(shape_tensor1->shape[0]);
const auto rank2 = detail::GetConstInt(shape_tensor2->shape[0]);
const auto out_rank = std::max<int32_t>(rank1, rank2);
const tvm::PrimExpr one = tvm::cast(shape_tensor1->dtype, PrimExpr(1));

auto select_dim = [&](const tvm::te::Tensor& shape_tensor, int rank,
tvm::tir::Var index) -> PrimExpr {
if (rank < out_rank) {
// if the rank is smaller, a dimension of 1 is prepended on the left,
// following NumPy broadcasting semantics.
return tvm::tir::Select(rank - (out_rank - index) < 0, one,
shape_tensor[rank - (out_rank - index)]);
} else {
// rank == out_rank, safe to index directly
return shape_tensor[index];
}
};

auto func = [&](tvm::Array<tvm::tir::Var> ovars) {
auto index = ovars[0];
PrimExpr dim1 = select_dim(shape_tensor1, rank1, index);
PrimExpr dim2 = select_dim(shape_tensor2, rank2, index);
if (topi::detail::EqualCheck(one, dim1)) {
return dim2;
} else if (topi::detail::EqualCheck(one, dim2)) {
return dim1;
}
return tvm::max(dim1, dim2);
};

Array<PrimExpr> oshape;
oshape.push_back(PrimExpr(out_rank));

return tvm::te::compute(oshape, func, name, tag);
}

#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \
inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; } \
inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B, \
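For reference, the rank-alignment rule this kernel encodes can be sketched on plain Python lists (a minimal illustration only; broadcast_shapes_py is a hypothetical helper, and like the kernel it assumes the inputs are broadcast-compatible; the EqualCheck branches above merely fold max(1, d) down to d when one side is statically 1):

def broadcast_shapes_py(shape1, shape2):
    # Right-align the shapes by prepending 1s to the shorter one,
    # then take the larger extent per axis (an extent of 1 broadcasts).
    out_rank = max(len(shape1), len(shape2))
    s1 = [1] * (out_rank - len(shape1)) + list(shape1)
    s2 = [1] * (out_rank - len(shape2)) + list(shape2)
    return [max(d1, d2) for d1, d2 in zip(s1, s2)]

assert broadcast_shapes_py([5, 5], [5]) == [5, 5]
assert broadcast_shapes_py([3, 1], [1, 4]) == [3, 4]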
6 changes: 4 additions & 2 deletions python/tvm/relay/op/_transform.py
@@ -817,6 +817,8 @@ def where_shape_func(attrs, inputs, _):
     """
     cond_shape = inputs[0]
     x_shape = inputs[1]
-    out_shape = x_shape if x_shape.shape else cond_shape
+    y_shape = inputs[2]
+    bcast_shape = topi.broadcast.broadcast_shape_tensors(x_shape, y_shape)
+    out_shape = topi.broadcast.broadcast_shape_tensors(bcast_shape, cond_shape)
 
-    return [out_shape]
+    return [topi.math.identity(out_shape)]
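As a sanity check on the two-step broadcast this shape function performs, NumPy's where follows the same rule; the snippet below mirrors the (3, 4) / (3, 1) / (1, 4) test case added in test_any.py (illustrative only):

import numpy as np

cond = np.random.randn(3, 4) > 0
x = np.random.randn(3, 1)
y = np.random.randn(1, 4)
# x and y broadcast to (3, 4); cond is then broadcast against that result,
# matching broadcast_shape_tensors(broadcast_shape_tensors(x_shape, y_shape), cond_shape).
assert np.where(cond, x, y).shape == (3, 4)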
22 changes: 21 additions & 1 deletion python/tvm/topi/broadcast.py
@@ -22,7 +22,7 @@
 def broadcast_to(data, shape):
     """Broadcast the src to the target shape
-    We follows the numpy broadcasting rule.
+    We follow the numpy broadcasting rule.
     See also https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
 
     Parameters
@@ -40,6 +40,26 @@ def broadcast_to(data, shape):
return _cpp.broadcast_to(data, shape)


def broadcast_shape_tensors(shape_tensor1, shape_tensor2):
    """Compute a shape tensor whose values represent the broadcasted shape
    of two input shape tensors

    Parameters
    ----------
    shape_tensor1 : tvm.te.Tensor
        One of the input shape tensors

    shape_tensor2 : tvm.te.Tensor
        One of the input shape tensors

    Returns
    -------
    ret : tvm.te.Tensor
        A shape tensor whose values represent the broadcasted shape
    """
    return _cpp.broadcast_shape_tensors(shape_tensor1, shape_tensor2)
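A minimal usage sketch of the new wrapper, assuming an LLVM target and int64 shape tensors (the shapes, names, and values here are illustrative, not part of the commit):

import numpy as np
import tvm
from tvm import te, topi

# Shape tensors: 1-D tensors whose values are the dims of some other tensor.
s1 = te.placeholder((3,), name="s1", dtype="int64")
s2 = te.placeholder((2,), name="s2", dtype="int64")
out = topi.broadcast.broadcast_shape_tensors(s1, s2)

sch = te.create_schedule(out.op)
f = tvm.build(sch, [s1, s2, out], "llvm")

a = tvm.nd.array(np.array([2, 3, 4], dtype="int64"))
b = tvm.nd.array(np.array([1, 4], dtype="int64"))
c = tvm.nd.empty((3,), dtype="int64")
f(a, b, c)
print(c.asnumpy())  # expected: [2 3 4]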


def add(lhs, rhs):
"""Addition with auto-broadcasting
4 changes: 4 additions & 0 deletions src/topi/broadcast.cc
@@ -76,5 +76,9 @@ TVM_REGISTER_GLOBAL("topi.broadcast_to").set_body([](TVMArgs args, TVMRetValue*
*rv = broadcast_to(args[0], args[1]);
});

TVM_REGISTER_GLOBAL("topi.broadcast_shape_tensors").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = broadcast_shape_tensors(args[0], args[1]);
});

} // namespace topi
} // namespace tvm
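The _cpp.broadcast_shape_tensors call in python/tvm/topi/broadcast.py reaches this registration through the FFI; fetching the PackedFunc directly is equivalent (illustrative sketch):

import tvm
from tvm import te

f = tvm.get_global_func("topi.broadcast_shape_tensors")
s1 = te.placeholder((2,), name="s1", dtype="int64")
s2 = te.placeholder((2,), name="s2", dtype="int64")
out = f(s1, s2)  # returns a te.Tensor, same as topi.broadcast.broadcast_shape_tensors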
31 changes: 31 additions & 0 deletions tests/python/relay/test_any.py
@@ -1236,5 +1236,36 @@ def test_any_stack():
verify_any_stack(any_dims(4), (2, 1, 1, 4), 2, 2)


def verify_any_where(cond_shape, x_shape, y_shape, cond_np_shape, x_np_shape, y_np_shape):
dtype = "float32"
cond = relay.var("cond", shape=cond_shape, dtype="bool")
x = relay.var("x", shape=x_shape, dtype=dtype)
y = relay.var("y", shape=y_shape, dtype=dtype)
z = relay.where(cond, x, y)
mod = tvm.IRModule()
mod["main"] = relay.Function([cond, x, y], z)

cond_np = np.random.randn(*cond_np_shape) > 0
x_np = np.random.randn(*x_np_shape).astype(dtype)
y_np = np.random.randn(*y_np_shape).astype(dtype)
expected = np.where(cond_np, x_np, y_np)

check_result([cond_np, x_np, y_np], mod, expected)


@tvm.testing.uses_gpu
def test_any_where():
verify_any_where(any_dims(1), (5,), (5,), (5,), (5,), (5,))
verify_any_where(any_dims(1), any_dims(1), (5,), (5,), (5,), (5,))
verify_any_where(any_dims(1), any_dims(1), any_dims(1), (5,), (5,), (5,))
verify_any_where((5,), any_dims(1), any_dims(1), (5,), (5,), (5,))

# where with broadcast
verify_any_where(any_dims(1), any_dims(1), any_dims(1), (5,), (1,), (5,))
verify_any_where(any_dims(1), any_dims(2), any_dims(2), (5,), (5, 5), (5, 5))
verify_any_where(any_dims(1), any_dims(1), any_dims(2), (5,), (5,), (5, 5))
verify_any_where(any_dims(2), any_dims(2), any_dims(2), (3, 4), (3, 1), (1, 4))


if __name__ == "__main__":
pytest.main([__file__])
