diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index fa27faf18f15..b670755d97b7 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -39,6 +39,8 @@
 #include <string>
 #include <vector>
 
+#include "detail/broadcast.h"
+
 namespace tvm {
 namespace topi {
 
@@ -887,53 +889,30 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int axis, std::string
  */
 inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
                     std::string name = "T_where", std::string tag = kBroadcast) {
-  ICHECK_EQ(x->shape.size(), y->shape.size())
-      << "x and y must have the same shape.Got different number of dimension: " << x->shape.size()
-      << " vs " << y->shape.size();
   ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
                                 << y->dtype;
+  auto get_out_shape = [&]() {
+    auto bh1 = detail::BroadcastShape(x->shape, y->shape);
+    Array<PrimExpr> common_shape1(bh1.common_shape.begin(), bh1.common_shape.end());
+    auto bh2 = detail::BroadcastShape(condition->shape, common_shape1);
+    Array<PrimExpr> common_shape2(bh2.common_shape.begin(), bh2.common_shape.end());
+    return common_shape2;
+  };
 
-  if (x->shape.size() == 0) {
-    return compute(
-        condition->shape,
-        [&](const Array<PrimExpr>& indices) {
-          PrimExpr cond;
-          if (condition->shape.size() == 0) {
-            cond = condition();
-          } else {
-            Array<PrimExpr> condition_idx{indices[0]};
-            cond = condition(condition_idx);
-          }
-          return tvm::tir::Select(cond != 0, x(), y());
-        },
-        name, tag);
-  } else if (condition->shape.size() != 1) {
-    ICHECK_EQ(condition->shape.size(), x->shape.size())
-        << "condition array must be either have the same shape as x or to be a "
-           "1-D array.Got different number of dimension: "
-        << condition->shape.size() << " vs " << x->shape.size();
-    return compute(
-        x->shape,
-        [&](const Array<PrimExpr>& indices) {
-          return tvm::tir::Select(condition(indices) != 0, x(indices), y(indices));
-        },
-        name, tag);
-  } else {
-    int64_t cond_first_dim = topi::GetConstInt(condition->shape[0]);
-    int64_t x_first_dim = topi::GetConstInt(x->shape[0]);
-    if (cond_first_dim > 0 && x_first_dim > 0) {
-      ICHECK_EQ(cond_first_dim, x_first_dim)
-          << "If condition is 1-D, the first dimension must be the same as x: " << cond_first_dim
-          << " vs " << x_first_dim;
-    }
-    return compute(
-        x->shape,
-        [&](const Array<PrimExpr>& indices) {
-          Array<PrimExpr> condition_idx{indices[0]};
-          return tvm::tir::Select(condition(condition_idx) != 0, x(indices), y(indices));
-        },
-        name, tag);
-  }
+  auto oshape = get_out_shape();
+
+  auto c_bh = detail::BroadcastShape(condition->shape, oshape);
+  auto x_bh = detail::BroadcastShape(x->shape, oshape);
+  auto y_bh = detail::BroadcastShape(y->shape, oshape);
+
+  auto select = [&](tvm::Array<tvm::tir::Var> ovars) {
+    auto c = condition(InputIndexFromBroadcast(ovars, condition, c_bh.vars1, c_bh.all_vars));
+    auto true_val = x(InputIndexFromBroadcast(ovars, x, x_bh.vars1, x_bh.all_vars));
+    auto false_val = y(InputIndexFromBroadcast(ovars, y, y_bh.vars1, y_bh.all_vars));
+    return tvm::tir::Select(c != 0, true_val, false_val);
+  };
+
+  return compute(oshape, select, name, tag);
 }
 
 /*!
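For orientation, the two-step shape computation that get_out_shape performs above can be sketched in plain Python (a minimal sketch of numpy-style broadcasting, not TVM API; the helper names are hypothetical). The same right-aligned rule is what _broadcast_shape_tensors implements in the shape function below:

    from itertools import zip_longest

    def broadcast_two(shape1, shape2):
        # Align shapes at their trailing dimensions; a dimension of 1
        # stretches to match the other operand, as in numpy.
        out = []
        for d1, d2 in zip_longest(reversed(shape1), reversed(shape2), fillvalue=1):
            assert d1 == d2 or d1 == 1 or d2 == 1, "Invalid broadcast shapes"
            out.append(max(d1, d2))
        return tuple(reversed(out))

    def where_out_shape(cond_shape, x_shape, y_shape):
        # First broadcast x against y, then condition against the result,
        # matching the order used in get_out_shape.
        return broadcast_two(cond_shape, broadcast_two(x_shape, y_shape))

    assert where_out_shape((5,), (3, 1), (1, 5)) == (3, 5)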
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index b135901baac3..3bb488c80b58 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -821,6 +821,33 @@ def stack_shape_func(attrs, inputs, _):
     return [_stack_shape_func(inputs[0], convert(axis), convert(len(inputs)))]
 
 
+@script
+def _broadcast_shape_tensors(shape_tensor1, shape_tensor2):
+    rank1 = shape_tensor1.shape[0]
+    rank2 = shape_tensor2.shape[0]
+    out_rank = max(rank1, rank2)
+    bcast_shape_tensor = output_tensor((out_rank,), "int64")
+
+    for index in const_range(out_rank):
+        dim1 = int64(1)
+        dim2 = int64(1)
+
+        if rank1 == out_rank:
+            dim1 = shape_tensor1[index]
+        elif rank1 - (out_rank - index) >= 0:
+            dim1 = shape_tensor1[rank1 - (out_rank - index)]
+
+        if rank2 == out_rank:
+            dim2 = shape_tensor2[index]
+        elif rank2 - (out_rank - index) >= 0:
+            dim2 = shape_tensor2[rank2 - (out_rank - index)]
+
+        assert dim1 == dim2 or dim1 == 1 or dim2 == 1, "Invalid broadcast shapes"
+        bcast_shape_tensor[index] = max(dim1, dim2)
+
+    return bcast_shape_tensor
+
+
 @_reg.register_shape_func("where", False)
 def where_shape_func(attrs, inputs, _):
     """
@@ -828,6 +855,9 @@ def where_shape_func(attrs, inputs, _):
     """
     cond_shape = inputs[0]
     x_shape = inputs[1]
-    out_shape = x_shape if x_shape.shape else cond_shape
+    y_shape = inputs[2]
+
+    bcast_shape = _broadcast_shape_tensors(x_shape, y_shape)
+    out_shape = _broadcast_shape_tensors(bcast_shape, cond_shape)
 
-    return [topi.math.identity(out_shape)]
+    return [out_shape]
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 01af60ebbd4b..17f4c02380b3 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -688,25 +688,26 @@ def where(condition, x, y):
     condition.
 
     .. note::
-        The shape of condition, x, and y needs to be the same.
+        Shapes of condition, x, and y must be broadcastable to a common shape.
+        Semantics follow the numpy where function:
+        https://numpy.org/doc/stable/reference/generated/numpy.where.html
 
     Parameters
    ----------
     condition : relay.Expr
-        The condition array. The n-th element in `y` is selected when the n-th
-        value in the `condition` array is zero. Otherwise, the corresponding
-        element from `x` will be picked.
+        Where True, yield x; otherwise yield y.
 
     x : relay.Expr
-        The first array to be selected.
+        The first array or scalar to be selected.
 
     y : relay.Expr
-        The second array to be selected.
+        The second array or scalar to be selected.
 
     Returns
     -------
     result : relay.Expr
-        The selected array.
+        The selected array. The output shape is the broadcasted shape from
+        condition, x, and y.
 
     Examples
     --------
@@ -717,7 +718,7 @@ def where(condition, x, y):
         condition = [[0, 1], [-1, 0]]
         relay.where(conditon, x, y) = [[5, 2], [3, 8]]
 
-        condition = [1, 0]
+        condition = [[1], [0]]
         relay.where(conditon, x, y) = [[1, 2], [7, 8]]
     """
     return _make.where(condition, x, y)
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 4a832ec8d962..02fd8930d332 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -45,6 +45,7 @@
 #include "../../transforms/pattern_utils.h"
 #include "../make_op.h"
 #include "../op_common.h"
+#include "../type_relations.h"
 
 namespace tvm {
 namespace relay {
@@ -1737,30 +1738,17 @@ bool WhereRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     return false;
   }
 
-  const auto& cond_shape = condition->shape;
-  const auto& x_shape = x->shape;
-  const auto& y_shape = y->shape;
-  ICHECK(x_shape.size() == y_shape.size()) << "x and y must have the same size";
+  ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
+                                << y->dtype;
 
-  if (cond_shape.size() != x_shape.size()) {
-    ICHECK_EQ(cond_shape.size(), 1) << "Shape of condition " << condition->shape
-                                    << " must be either equal to x or has dimension of 1.";
-  }
-  for (size_t i = 0; i < x_shape.size(); i++) {
-    ICHECK(reporter->AssertEQ(x_shape[i], y_shape[i]))
-        << "x and y must have the same shape: " << x_shape << " vs " << y_shape;
+  auto tensor_ty_condition = GetRef<TensorType>(condition);
+  auto tensor_ty_x = GetRef<TensorType>(x);
+  auto tensor_ty_y = GetRef<TensorType>(y);
 
-    if (i < cond_shape.size()) {
-      ICHECK(reporter->AssertEQ(cond_shape[i], x_shape[i]))
-          << "condition and x must have the same shape: " << cond_shape << " vs " << x_shape;
-    }
-  }
-  if (x_shape.size() == 0) {
-    // if x and y are scalar, the condition shape becomes the output shape
-    reporter->Assign(types[3], TensorType(cond_shape, x->dtype));
-  } else {
-    reporter->Assign(types[3], TensorType(x_shape, x->dtype));
-  }
+  auto b_ty = ConcreteBroadcast(tensor_ty_x, tensor_ty_y, x->dtype);
+  auto ret_ty = ConcreteBroadcast(tensor_ty_condition, b_ty, b_ty->dtype);
+
+  reporter->Assign(types[3], ret_ty);
   return true;
 }
 
@@ -1783,17 +1771,10 @@ Return the elements, either from x or y, depending on the condition.
 
 Given three ndarrays, condition, x, and y, return an ndarray with the elements
 from x or y, depending on the elements from condition are true or false.
-x and y must have the same shape. If condition has the same shape as x,
-each element in the output array is from x if the corresponding element
-in the condition is true, and from y if false.
-
-If condition does not have the same shape as x, it must be a 1D array whose
-size is the same as x’s first dimension size. Each row of the output array
-is from x’s row if the corresponding element from condition is true, and
-from y’s row if false.
-
-When x and y are scalars, condition must be an 1D array. The output shape
-is the same as condition's shape.
+Shapes of condition, x, and y must be broadcastable to a common shape, which
+is the output shape of this op. Semantics follow the numpy where function:
+https://numpy.org/doc/stable/reference/generated/numpy.where.html
 
 Note that all non-zero values are interpreted as True in condition.
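As a sanity check on the semantics described here, the documented examples can be reproduced with plain numpy (illustration only, not part of the patch):

    import numpy as np

    x = np.array([[1, 2], [3, 4]])
    y = np.array([[5, 6], [7, 8]])

    # Same-shape condition: elementwise selection.
    cond = np.array([[0, 1], [-1, 0]])
    assert (np.where(cond != 0, x, y) == [[5, 2], [3, 8]]).all()

    # A (2, 1) condition broadcasts across columns: row-wise selection.
    cond = np.array([[1], [0]])
    assert (np.where(cond != 0, x, y) == [[1, 2], [7, 8]]).all()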
@@ -1805,7 +1786,7 @@ Examples::
 
   where(cond, x, y) = [[5, 2], [3, 8]]
 
-  cond = [1, 0]
+  cond = [[1], [0]]
   where(cond, x, y) = [[1, 2], [7, 8]]
 
   cond = [0, 1]
diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc
index 3dc33c5022e0..7a3bfcb21ce6 100644
--- a/src/relay/op/type_relations.cc
+++ b/src/relay/op/type_relations.cc
@@ -64,7 +64,7 @@ bool EqualConstInt(const IndexExpr& lhs, int64_t value) {
   return false;
 }
 
-Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) {
+TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) {
   std::vector<IndexExpr> oshape;
   size_t ndim1 = t1->shape.size();
   size_t ndim2 = t2->shape.size();
diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h
index 5ab8b121ae9d..6d6d5f70c0c2 100644
--- a/src/relay/op/type_relations.h
+++ b/src/relay/op/type_relations.h
@@ -57,6 +57,15 @@ bool IdentityRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                  const TypeReporter& reporter);
 
 bool BroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter);
 
+/*!
+ * \brief Determine the broadcast shape from two input shapes
+ * \param t1 One of the two TensorTypes whose shapes are broadcast
+ * \param t2 One of the two TensorTypes whose shapes are broadcast
+ * \param output_dtype dtype of the output TensorType
+ * \return A TensorType whose shape is the broadcast of the two input shapes.
+ */
+TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype);
+
 /*!
  * \brief The broadcast type relation, implements the broadcasting
  *        rule over the two input types producing the broadcasted type.
diff --git a/tests/python/frontend/pytorch/test_lstm.py b/tests/python/frontend/pytorch/test_lstm.py
index 39d78c70c0fb..1197990f54ba 100644
--- a/tests/python/frontend/pytorch/test_lstm.py
+++ b/tests/python/frontend/pytorch/test_lstm.py
@@ -277,6 +277,8 @@ def test_custom_lstm():
     num_layers = 3
     state_tensor_shape = (batch, hidden_size)
 
+    torch.manual_seed(1)
+
    inp = torch.randn(seq_len, batch, input_size)
 
     input_shapes = [
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 872728514c3e..b1b068ebb32a 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -1236,5 +1236,49 @@ def test_any_stack():
     verify_any_stack(any_dims(4), (2, 1, 1, 4), 2, 2)
 
 
+def verify_any_where(
+    cond_shape, x_shape, y_shape, cond_np_shape, x_np_shape, y_np_shape, y_np_shape_invalid=None
+):
+    dtype = "float32"
+    cond = relay.var("cond", shape=cond_shape, dtype="bool")
+    x = relay.var("x", shape=x_shape, dtype=dtype)
+    y = relay.var("y", shape=y_shape, dtype=dtype)
+    z = relay.where(cond, x, y)
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function([cond, x, y], z)
+
+    cond_np = np.random.randn(*cond_np_shape) > 0
+    x_np = np.random.randn(*x_np_shape).astype(dtype)
+    y_np = np.random.randn(*y_np_shape).astype(dtype)
+    expected = np.where(cond_np, x_np, y_np)
+
+    check_result([cond_np, x_np, y_np], mod, expected)
+
+    # verify invalid broadcasting check
+    if y_np_shape_invalid:
+        y_np_bad = np.random.randn(*y_np_shape_invalid).astype(dtype)
+        try:
+            check_result([cond_np, x_np, y_np_bad], mod, expected)
+        except tvm.error.TVMError as e:
+            error_msg = str(e).split("\n")[-1]
+            assert "Invalid broadcast shapes" in error_msg
+
+
+@tvm.testing.uses_gpu
+def test_any_where():
+    verify_any_where(any_dims(1), (5,), (5,), (5,), (5,), (5,))
+    verify_any_where(any_dims(1), any_dims(1), (5,), (5,), (5,), (5,))
+    verify_any_where(any_dims(1), any_dims(1), any_dims(1), (5,), (5,), (5,))
+    verify_any_where((5,), any_dims(1), any_dims(1), (5,), (5,), (5,))
+
+    # where with broadcast
+    verify_any_where(any_dims(1), any_dims(1), any_dims(1), (5,), (1,), (5,))
+    verify_any_where(any_dims(1), any_dims(2), any_dims(2), (5,), (5, 5), (5, 5))
+    verify_any_where(any_dims(1), any_dims(1), any_dims(2), (5,), (5,), (5, 5))
+    verify_any_where(
+        any_dims(2), any_dims(2), any_dims(2), (3, 4), (3, 1), (1, 4), y_np_shape_invalid=(2, 4)
+    )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py
index eafc743634d8..ef363430a2eb 100644
--- a/tests/python/relay/test_op_level4.py
+++ b/tests/python/relay/test_op_level4.py
@@ -152,35 +152,70 @@ def run(func, inputs, ref_res):
             op_res = intrp.evaluate(func)(*inputs)
             tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
 
-    shape = (3, 4)
-    dtype = "float32"
-    cond = relay.var("cond", relay.TensorType(shape, dtype))
-    x = relay.var("x", relay.TensorType(shape, dtype))
-    y = relay.var("y", relay.TensorType(shape, dtype))
-    z = relay.where(cond, x, y)
-    zz = run_infer_type(z)
-    assert zz.checked_type == relay.TensorType(shape, dtype)
+    def verify(x_np, y_np, cond_np):
+        ref_res = np.where(cond_np, x_np, y_np)
+
+        args = []
+        args_np = []
+        vs = []
+
+        cond = relay.var("cond", relay.TensorType(cond_np.shape, "bool"))
 
-    func = relay.Function([cond, x, y], z)
-    condition = np.random.uniform(low=-1, high=1, size=shape).astype(dtype)
-    x = np.random.uniform(size=shape).astype(dtype)
-    y = np.random.uniform(size=shape).astype(dtype)
-    ref_res = np.where(condition, x, y)
+        args.append(cond)
+        args_np.append(cond_np)
 
-    run(func, [condition, x, y], ref_res)
+        for v_name, v_np in [("x", x_np), ("y", y_np)]:
+            if len(v_np.shape) == 0:
+                v = relay.const(v_np.item())
+            else:
+                v = relay.var(v_name, relay.TensorType(v_np.shape, dtype))
+                args.append(v)
+                args_np.append(v_np)
+            vs.append(v)
+
+        z = relay.where(cond, vs[0], vs[1])
+
+        func = relay.Function(args, z)
+
+        run(func, args_np, ref_res)
 
-    x = relay.const(1)
-    y = relay.const(-1)
-    shape = (3,)
     dtype = "float32"
-    cond = relay.var("cond", relay.TensorType(shape, "bool"))
-    z = relay.where(cond, x, y)
-    func = relay.Function([cond], z)
-    condition = np.array([1, 0, 1], dtype=np.bool)
-    ref_res = np.where(condition, 1, -1)
+    x_np = np.random.uniform(size=(3, 4)).astype(dtype)
+    y_np = np.random.uniform(size=(3, 4)).astype(dtype)
+    cond_np = np.random.uniform(low=-1, high=1, size=(3, 4)) > 0
+
+    verify(x_np, y_np, cond_np)
+
+    x_np = np.array(1.0, dtype)
+    y_np = np.array(-1.0, dtype)
+    cond_np = np.array([1, 0, 1], dtype=np.bool)
+
+    verify(x_np, y_np, cond_np)
+
+    x_np = np.arange(10).astype(dtype)
+    y_np = 10 * x_np
+    cond_np = x_np < 5
+
+    verify(x_np, y_np, cond_np)
+
+    x_np = np.array([[1, 2], [3, 4]], dtype)
+    y_np = np.array([[5, 6], [7, 8]], dtype)
+    cond_np = np.array([[1], [0]], dtype=np.bool)
+
+    verify(x_np, y_np, cond_np)
+    verify(x_np, y_np, cond_np.T)
+
+    x_np = np.random.randn(1, 12, 8, 8).astype(dtype)
+    y_np = np.array(-1.0, dtype)
+    cond_np = np.random.randn(1, 1, 8, 8) > 0
+
+    verify(x_np, y_np, cond_np)
+
+    x_np, y_np = np.ogrid[:3, :4]
+    cond_np = np.where(x_np < y_np, x_np, 10 + y_np).astype(np.bool)
 
-    run(func, [condition], ref_res)
+    verify(x_np.astype(dtype), y_np.astype(dtype), cond_np)
 
 
 def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"):
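Beyond the executable checks above, the relaxed WhereRel can also be observed directly through type inference; a minimal sketch (assuming the run_infer_type helper from tvm.relay.testing, as imported elsewhere in these tests):

    from tvm import relay
    from tvm.relay.testing import run_infer_type

    cond = relay.var("cond", relay.TensorType((3, 1), "bool"))
    x = relay.var("x", relay.TensorType((3, 4), "float32"))
    y = relay.var("y", relay.TensorType((1, 4), "float32"))

    # WhereRel broadcasts x against y, then condition against the result.
    z = run_infer_type(relay.where(cond, x, y))
    assert z.checked_type == relay.TensorType((3, 4), "float32")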