Skip to content

Commit

Permalink
add gather_nd shape func
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 27, 2021
1 parent 1853c35 commit 36a4501
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 4 deletions.
7 changes: 7 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,18 @@ struct GatherAttrs : public tvm::AttrsNode<GatherAttrs> {

struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
Integer batch_dims;
Integer gather_dim;

TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherNDAttrs") {
TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions.");
TVM_ATTR_FIELD(gather_dim)
.set_default(Integer(-1))
.describe(
"The size of an indexing tuple, which is a fixed value. Only needed when the number of "
"indexting tuples is dynamic.");
}
};

struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
Integer batch_dims;
Integer axis;
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,8 +1434,10 @@ class GatherND(OnnxOpConverter):
@classmethod
def _impl_common(cls, data, indices, batch_dims=0):
indices_dims = len(infer_shape(indices))
indices_shape = infer_shape(indices)
indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1)))
return _op.gather_nd(data, indices, batch_dims)
gather_dim = indices_shape[-1]
return _op.gather_nd(data, indices, batch_dims, gather_dim)

@classmethod
def _impl_v1(cls, inputs, attr, params):
Expand Down
31 changes: 31 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,3 +1127,34 @@ def unique_shape_func(attrs, inputs, _):
return _unique_with_counts_shape(inputs[0])
else:
return _unique_shape(inputs[0])


@script
def _gather_nd_shape(data_shape, indices_shape, batch_dims, gather_dim):
ndim = data_shape.shape[0]
mdim = gather_dim
# using mdim = indices_shape[0] wouldn't work because a rank cannot
# depend on a runtime shape dimension of indices tensor, even if the
# dimension is always a known, fixed value. As a workaround, we assume that
# the fixed gather dimension (the size of an indexing tuple) is recorded
# in `gather_nd` op attribute.
err_msg = "The recorded gather dimension and the actual dimension are different"
assert mdim == indices_shape[0], err_msg
kdim = indices_shape.shape[0] - 1
out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64")
for i in range(1, kdim + 1):
out_shape[i-1] = indices_shape[i]
for i in range(mdim + batch_dims, ndim):
out_shape[kdim + i - (mdim + batch_dims)] = data_shape[i]
return out_shape


@_reg.register_shape_func("gather_nd", False)
def gather_nd_shape_func(attrs, inputs, _):
"""
Shape func for ghater_nd operator.
"""
batch_dims = get_const_int(attrs.batch_dimss)
gather_dim = get_const_int(attrs.gather_dim)
assert gather_dim > 0, "gather_dim needs to be specified for dynamic gather_nd"
return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(gather_dim))]
8 changes: 6 additions & 2 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,7 @@ def gather(data, axis, indices):
return _make.gather(data, axis, indices)


def gather_nd(data, indices, batch_dims=0):
def gather_nd(data, indices, batch_dims=0, gather_dim=-1):
"""Gather elements or slices from data and store to a tensor whose shape is
defined by indices.
Expand All @@ -1090,6 +1090,10 @@ def gather_nd(data, indices, batch_dims=0):
batch_dims : int
The number of batch dimensions.
gather_dim : int
The size of an indexing tuple, which is a fixed value and the same as indices.shape[0]
Only needed when other dimensions of indices are dynamic.
Returns
-------
ret : relay.Expr
Expand All @@ -1111,7 +1115,7 @@ def gather_nd(data, indices, batch_dims=0):
indices = [[1, 0]]
relay.gather_nd(data, indices, batch_dims=1) = [[2,3],[4,5]]
"""
return _make.gather_nd(data, indices, batch_dims)
return _make.gather_nd(data, indices, batch_dims, gather_dim)


def sequence_mask(data, valid_length, mask_value=0, axis=0):
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/topi/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ def scatter(data, indices, updates, axis=0):


def _verify_scatter_nd_inputs(data, indices, updates):
# TODO(masahi): revisit
return
mdim = int(indices.shape[0])
assert mdim <= len(data.shape), (
f"The first dimension of the indices ({mdim}) must be less than or equal to "
Expand Down
3 changes: 2 additions & 1 deletion src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3600,10 +3600,11 @@ Array<te::Tensor> GatherNDCompute(const Attrs& attrs, const Array<te::Tensor>& i
return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)};
}

Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0) {
Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, int gather_dim = -1) {
static const Op& op = Op::Get("gather_nd");
auto attrs = make_object<GatherNDAttrs>();
attrs->batch_dims = batch_dims;
attrs->gather_dim = gather_dim;
return Call(op, {data, indices}, Attrs(attrs));
}

Expand Down

0 comments on commit 36a4501

Please sign in to comment.