Skip to content

Commit

Permalink
refactor gather_nd ref funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 27, 2021
1 parent 36a4501 commit 533854a
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 24 deletions.
6 changes: 2 additions & 4 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,14 +1132,12 @@ def unique_shape_func(attrs, inputs, _):
@script
def _gather_nd_shape(data_shape, indices_shape, batch_dims, gather_dim):
ndim = data_shape.shape[0]
mdim = gather_dim
# using mdim = indices_shape[0] wouldn't work because a rank cannot
# depend on a runtime shape dimension of indices tensor, even if the
# dimension is always a known, fixed value. As a workaround, we assume that
# the fixed gather dimension (the size of an indexing tuple) is recorded
# in `gather_nd` op attribute.
err_msg = "The recorded gather dimension and the actual dimension are different"
assert mdim == indices_shape[0], err_msg
mdim = gather_dim
kdim = indices_shape.shape[0] - 1
out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64")
for i in range(1, kdim + 1):
Expand All @@ -1154,7 +1152,7 @@ def gather_nd_shape_func(attrs, inputs, _):
"""
Shape func for ghater_nd operator.
"""
batch_dims = get_const_int(attrs.batch_dimss)
batch_dims = get_const_int(attrs.batch_dims)
gather_dim = get_const_int(attrs.gather_dim)
assert gather_dim > 0, "gather_dim needs to be specified for dynamic gather_nd"
return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(gather_dim))]
22 changes: 2 additions & 20 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tvm.error import TVMError
from tvm.relay import create_executor, transform
from tvm.relay.testing import check_grad, run_infer_type
from utils import ref_funcs


def test_zeros_ones():
Expand Down Expand Up @@ -1266,26 +1267,7 @@ def verify_gather_nd(xshape, yshape, y_data, batch_dims=0):
else:
y_data = np.random.randint(low=0, high=2, size=yshape, dtype="int32")

def gather_nd_batch_dims_1_ref(data, indices):
res = []
for i, row in enumerate(data):
indices_tuple = tuple(indices[:, i]) # the indices for the i-th batch
res.append(row[indices_tuple])
# stack on the batch dim
return np.stack(res, 0)

if batch_dims > 1:
x_data_reshape = np.reshape(x_data, (-1,) + xshape[batch_dims:])
y_data_reshape = np.reshape(y_data, (yshape[0], -1) + yshape[(batch_dims + 1) :])

ref_res = gather_nd_batch_dims_1_ref(x_data_reshape, y_data_reshape)

out_shape = yshape[1 : (batch_dims + 1)] + ref_res.shape[1:]
ref_res = np.reshape(ref_res, out_shape)
elif batch_dims == 1:
ref_res = gather_nd_batch_dims_1_ref(x_data, y_data)
else:
ref_res = x_data[tuple(y_data)]
ref_res = ref_funcs.gather_nd(x_data, y_data, batch_dims)

for target, dev in tvm.testing.enabled_targets():
for kind in ["graph", "debug"]:
Expand Down
48 changes: 48 additions & 0 deletions tests/python/relay/utils/ref_funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np


def gather_nd(data_np, indices_np, batch_dims=0):
"""gather_nd implemented using numpy"""
data_shape = data_np.shape
indices_shape = indices_np.shape

def gather_nd_batch_dims_1_ref(data, indices):
res = []
for i, row in enumerate(data):
indices_tuple = tuple(indices[:, i]) # the indices for the i-th batch
res.append(row[indices_tuple])
# stack on the batch dim
return np.stack(res, 0)

if batch_dims > 1:
data_np_reshape = np.reshape(data_np, (-1,) + data_shape[batch_dims:])
indices_np_reshape = np.reshape(
indices_np, (indices_shape[0], -1) + indices_shape[(batch_dims + 1) :]
)

ref_res = gather_nd_batch_dims_1_ref(data_np_reshape, indices_np_reshape)

out_shape = indices_shape[1 : (batch_dims + 1)] + ref_res.shape[1:]
ref_res = np.reshape(ref_res, out_shape)
elif batch_dims == 1:
ref_res = gather_nd_batch_dims_1_ref(data_np, indices_np)
else:
ref_res = data_np[tuple(indices_np)]

return ref_res

0 comments on commit 533854a

Please sign in to comment.