diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index b09b035b6bc3..af59928641ba 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -891,16 +891,22 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y, << " vs " << y->shape.size(); CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs " << y->dtype; - Array oshape = x->shape; - Tensor out; - if (condition->shape.size() != 1) { + if (x->shape.size() == 0) { + return compute( + condition->shape, + [&](const Array& indices) { + Array condition_idx{indices[0]}; + return tvm::tir::Select(condition(condition_idx) != 0, x(), y()); + }, + name, tag); + } else if (condition->shape.size() != 1) { CHECK_EQ(condition->shape.size(), x->shape.size()) << "condition array must be either have the same shape as x or to be a " "1-D array.Got different number of dimension: " << condition->shape.size() << " vs " << x->shape.size(); - out = compute( - oshape, + return compute( + x->shape, [&](const Array& indices) { return tvm::tir::Select(condition(indices) != 0, x(indices), y(indices)); }, @@ -909,15 +915,14 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y, CHECK_EQ(topi::GetConstInt(condition->shape[0]), topi::GetConstInt(x->shape[0])) << "If condition is 1-D, the first dimension must be the same as x: " << condition->shape[0] << " vs " << x->shape[0]; - out = compute( - oshape, + return compute( + x->shape, [&](const Array& indices) { Array condition_idx{indices[0]}; return tvm::tir::Select(condition(condition_idx) != 0, x(indices), y(indices)); }, name, tag); } - return out; } /*! 
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 5126b1d34d3e..40051e43d57b 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1662,7 +1662,12 @@ bool WhereRel(const Array& types, int num_inputs, const Attrs& attrs, << "condition and x must have the same shape: " << cond_shape << " vs " << x_shape; } } - reporter->Assign(types[3], TensorType(x_shape, x->dtype)); + if (x_shape.size() == 0) { + // if x and y are scalar, the condition shape becomes the output shape + reporter->Assign(types[3], TensorType(cond_shape, x->dtype)); + } else { + reporter->Assign(types[3], TensorType(x_shape, x->dtype)); + } return true; } @@ -1694,6 +1699,9 @@ size is the same as x’s first dimension size. Each row of the output array is from x’s row if the corresponding element from condition is true, and from y’s row if false. +When x and y are scalars, condition must be a 1-D array. The output shape +is the same as condition's shape. + Note that all non-zero values are interpreted as True in condition. 
Examples:: @@ -1707,6 +1715,9 @@ Examples:: cond = [1, 0] where(cond, x, y) = [[1, 2], [7, 8]] + cond = [0, 1] + where(cond, 1, -1) = [-1, 1] + )code" TVM_ADD_FILELINE) .add_argument("condition", "Tensor", "Condition array") .add_argument("x", "Tensor", "First array to be selected") diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index af3826448edb..8c62f8c0727f 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -144,6 +144,13 @@ def test_binary_int_broadcast_2(): @tvm.testing.uses_gpu def test_where(): + def run(func, inputs, ref_res): + for target, ctx in tvm.testing.enabled_targets(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(*inputs) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + shape = (3, 4) dtype = "float32" cond = relay.var("cond", relay.TensorType(shape, dtype)) @@ -158,11 +165,21 @@ def test_where(): x = np.random.uniform(size=shape).astype(dtype) y = np.random.uniform(size=shape).astype(dtype) ref_res = np.where(condition, x, y) - for target, ctx in tvm.testing.enabled_targets(): - for kind in ["graph", "debug"]: - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(func)(condition, x, y) - tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + + run(func, [condition, x, y], ref_res) + + x = relay.const(1) + y = relay.const(-1) + shape = (3,) + dtype = "float32" + cond = relay.var("cond", relay.TensorType(shape, "bool")) + z = relay.where(cond, x, y) + + func = relay.Function([cond], z) + condition = np.array([1, 0, 1], dtype=np.bool) + ref_res = np.where(condition, 1, -1) + + run(func, [condition], ref_res) def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"): @@ -232,12 +249,12 @@ def _np_log_sum_exp(x, axis, keepdims=False): if not keepdims: x = np.squeeze(x, axis=axis) return x - + 
def _unbiased_relay_wrapper(f): def _unbiased_func(x, axis=None, keepdims=False, exclude=False): return f(x, axis=axis, keepdims=keepdims, exclude=exclude, unbiased=True) return _unbiased_func - + def _unbiased_np_wrapper(f): def _unbiased_func(a, axis=None, dtype=None, keepdims=None): return f(a, axis=axis, dtype=dtype, ddof=1, keepdims=keepdims)