Skip to content

Commit

Permalink
support dynamic scatter nd
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 27, 2021
1 parent d4a4db8 commit 96c9231
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
8 changes: 5 additions & 3 deletions python/tvm/topi/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
"""Scatter operator"""
from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate
from ..tir import decl_buffer, ir_builder, AssertStmt, StringImm, Evaluate, expr
from ..te import extern, hybrid


Expand Down Expand Up @@ -200,20 +200,22 @@ def scatter(data, indices, updates, axis=0):


def _verify_scatter_nd_inputs(data, indices, updates):
# TODO(masahi): revisit
return
mdim = int(indices.shape[0])
assert mdim <= len(data.shape), (
f"The first dimension of the indices ({mdim}) must be less than or equal to "
f"the length of the shape of the output ({len(shape)})."
)
for i in range(len(indices.shape) - 1):
if isinstance(indices.shape[i + 1], expr.Var) or isinstance(updates.shape[i], expr.Var):
continue
assert indices.shape[i + 1] == updates.shape[i], (
f"Dimension of indices[{i+1}] ({indices.shape[i+1]}) must equal dimension of "
f"updates[{i}] ({updates.shape[i]})."
)
for i in range(mdim, len(data.shape)):
data_ind = i - mdim + len(indices.shape) - 1
if isinstance(updates.shape[data_ind], expr.Var) or isinstance(data.shape[i], expr.Var):
continue
assert updates.shape[data_ind] == data.shape[i], (
f"Dimension of updates[{data_ind}] ({updates.shape[data_ind]}) must equal dimension "
f"of out_shape[{i}] ({data.shape[i]})."
Expand Down
23 changes: 23 additions & 0 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -1728,5 +1728,28 @@ def verify_gather_nd(data_shape, indices_shape, data_shape_np, indices_shape_np,
)


@tvm.testing.uses_gpu
def test_scatter_nd():
def verify_scatter_nd(data_np, indices_np, updates_np, ref_res):
indices_shape = (2, relay.Any())
updates_shape = (relay.Any(),)
data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
indices = relay.var("indices", relay.TensorType(indices_shape, str(indices_np.dtype)))
updates = relay.var("updates", relay.TensorType(updates_shape, str(updates_np.dtype)))

out = relay.op.scatter_nd(data, indices, updates, "add")

mod = tvm.IRModule()
mod["main"] = relay.Function([data, indices, updates], out)

check_result([data_np, indices_np, updates_np], mod, [ref_res])

data = np.zeros((2, 2)).astype("int64")
indices = np.array([[1, 1, 0], [0, 1, 0]])
updates = np.array([2, 3, 0])
out = np.array([[0, 0], [2, 3]])
verify_scatter_nd(data, indices, updates, out)


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 96c9231

Please sign in to comment.