Skip to content

Commit

Permalink
[Relay/topi] Support scalar inputs in where op (#6383)
Browse files Browse the repository at this point in the history
* support where with scalars

* add test for where with scalar

* add comment
  • Loading branch information
masahi committed Sep 4, 2020
1 parent 86fa81c commit 8508ec3
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 16 deletions.
21 changes: 13 additions & 8 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -891,16 +891,22 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
<< " vs " << y->shape.size();
CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
<< y->dtype;
Array<PrimExpr> oshape = x->shape;
Tensor out;

if (condition->shape.size() != 1) {
if (x->shape.size() == 0) {
return compute(
condition->shape,
[&](const Array<Var>& indices) {
Array<PrimExpr> condition_idx{indices[0]};
return tvm::tir::Select(condition(condition_idx) != 0, x(), y());
},
name, tag);
} else if (condition->shape.size() != 1) {
CHECK_EQ(condition->shape.size(), x->shape.size())
<< "condition array must be either have the same shape as x or to be a "
"1-D array.Got different number of dimension: "
<< condition->shape.size() << " vs " << x->shape.size();
out = compute(
oshape,
return compute(
x->shape,
[&](const Array<Var>& indices) {
return tvm::tir::Select(condition(indices) != 0, x(indices), y(indices));
},
Expand All @@ -909,15 +915,14 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
CHECK_EQ(topi::GetConstInt(condition->shape[0]), topi::GetConstInt(x->shape[0]))
<< "If condition is 1-D, the first dimension must be the same as x: " << condition->shape[0]
<< " vs " << x->shape[0];
out = compute(
oshape,
return compute(
x->shape,
[&](const Array<Var>& indices) {
Array<PrimExpr> condition_idx{indices[0]};
return tvm::tir::Select(condition(condition_idx) != 0, x(indices), y(indices));
},
name, tag);
}
return out;
}

/*!
Expand Down
13 changes: 12 additions & 1 deletion src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1662,7 +1662,12 @@ bool WhereRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
<< "condition and x must have the same shape: " << cond_shape << " vs " << x_shape;
}
}
reporter->Assign(types[3], TensorType(x_shape, x->dtype));
if (x_shape.size() == 0) {
// if x and y are scalar, the condition shape becomes the output shape
reporter->Assign(types[3], TensorType(cond_shape, x->dtype));
} else {
reporter->Assign(types[3], TensorType(x_shape, x->dtype));
}
return true;
}

Expand Down Expand Up @@ -1694,6 +1699,9 @@ size is the same as x’s first dimension size. Each row of the output array
is from x’s row if the corresponding element from condition is true, and
from y’s row if false.
When x and y are scalars, condition must be a 1-D array. The output shape
is the same as condition's shape.
Note that all non-zero values are interpreted as True in condition.
Examples::
Expand All @@ -1707,6 +1715,9 @@ Examples::
cond = [1, 0]
where(cond, x, y) = [[1, 2], [7, 8]]
cond = [0, 1]
where(cond, 1, -1) = [-1, 1]
)code" TVM_ADD_FILELINE)
.add_argument("condition", "Tensor", "Condition array")
.add_argument("x", "Tensor", "First array to be selected")
Expand Down
31 changes: 24 additions & 7 deletions tests/python/relay/test_op_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ def test_binary_int_broadcast_2():

@tvm.testing.uses_gpu
def test_where():
def run(func, inputs, ref_res):
for target, ctx in tvm.testing.enabled_targets():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(*inputs)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

shape = (3, 4)
dtype = "float32"
cond = relay.var("cond", relay.TensorType(shape, dtype))
Expand All @@ -158,11 +165,21 @@ def test_where():
x = np.random.uniform(size=shape).astype(dtype)
y = np.random.uniform(size=shape).astype(dtype)
ref_res = np.where(condition, x, y)
for target, ctx in tvm.testing.enabled_targets():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(condition, x, y)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

run(func, [condition, x, y], ref_res)

x = relay.const(1)
y = relay.const(-1)
shape = (3,)
dtype = "float32"
cond = relay.var("cond", relay.TensorType(shape, "bool"))
z = relay.where(cond, x, y)

func = relay.Function([cond], z)
condition = np.array([1, 0, 1], dtype=np.bool)
ref_res = np.where(condition, 1, -1)

run(func, [condition], ref_res)


def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"):
Expand Down Expand Up @@ -232,12 +249,12 @@ def _np_log_sum_exp(x, axis, keepdims=False):
if not keepdims:
x = np.squeeze(x, axis=axis)
return x

def _unbiased_relay_wrapper(f):
def _unbiased_func(x, axis=None, keepdims=False, exclude=False):
return f(x, axis=axis, keepdims=keepdims, exclude=exclude, unbiased=True)
return _unbiased_func

def _unbiased_np_wrapper(f):
def _unbiased_func(a, axis=None, dtype=None, keepdims=None):
return f(a, axis=axis, dtype=dtype, ddof=1, keepdims=keepdims)
Expand Down

0 comments on commit 8508ec3

Please sign in to comment.