From 36a4501a151070760559f6ce4cfa574202b4d0c8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 21 May 2021 14:36:34 +0900 Subject: [PATCH] add gather_nd shape func --- include/tvm/relay/attrs/transform.h | 7 +++++++ python/tvm/relay/frontend/onnx.py | 4 +++- python/tvm/relay/op/_transform.py | 31 +++++++++++++++++++++++++++++ python/tvm/relay/op/transform.py | 8 ++++++-- python/tvm/topi/scatter.py | 2 ++ src/relay/op/tensor/transform.cc | 3 ++- 6 files changed, 51 insertions(+), 4 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index d091342a5e4a..341860146743 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -146,11 +146,18 @@ struct GatherAttrs : public tvm::AttrsNode { struct GatherNDAttrs : public tvm::AttrsNode { Integer batch_dims; + Integer gather_dim; TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherNDAttrs") { TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions."); + TVM_ATTR_FIELD(gather_dim) + .set_default(Integer(-1)) + .describe( + "The size of an indexing tuple, which is a fixed value. 
Only needed when the number of " "indexing tuples is dynamic."); } }; + struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> { Integer batch_dims; Integer axis; diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 33e5340654cc..4e77f38184d0 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1434,8 +1434,10 @@ class GatherND(OnnxOpConverter): @classmethod def _impl_common(cls, data, indices, batch_dims=0): indices_dims = len(infer_shape(indices)) + indices_shape = infer_shape(indices) indices = _op.transpose(indices, axes=[-1] + list(range(indices_dims - 1))) - return _op.gather_nd(data, indices, batch_dims) + gather_dim = indices_shape[-1] + return _op.gather_nd(data, indices, batch_dims, gather_dim) @classmethod def _impl_v1(cls, inputs, attr, params): diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 37b384dcfc31..ecd63f60f2b7 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -1127,3 +1127,34 @@ def unique_shape_func(attrs, inputs, _): return _unique_with_counts_shape(inputs[0]) else: return _unique_shape(inputs[0]) + + +@script +def _gather_nd_shape(data_shape, indices_shape, batch_dims, gather_dim): + ndim = data_shape.shape[0] + mdim = gather_dim + # using mdim = indices_shape[0] wouldn't work because a rank cannot + # depend on a runtime shape dimension of indices tensor, even if the + # dimension is always a known, fixed value. As a workaround, we assume that + # the fixed gather dimension (the size of an indexing tuple) is recorded + # in `gather_nd` op attribute. 
+ err_msg = "The recorded gather dimension and the actual dimension are different" + assert mdim == indices_shape[0], err_msg + kdim = indices_shape.shape[0] - 1 + out_shape = output_tensor((kdim + ndim - (mdim + batch_dims),), "int64") + for i in range(1, kdim + 1): + out_shape[i-1] = indices_shape[i] + for i in range(mdim + batch_dims, ndim): + out_shape[kdim + i - (mdim + batch_dims)] = data_shape[i] + return out_shape + + +@_reg.register_shape_func("gather_nd", False) +def gather_nd_shape_func(attrs, inputs, _): + """ + Shape func for gather_nd operator. + """ + batch_dims = get_const_int(attrs.batch_dims) + gather_dim = get_const_int(attrs.gather_dim) + assert gather_dim > 0, "gather_dim needs to be specified for dynamic gather_nd" + return [_gather_nd_shape(inputs[0], inputs[1], convert(batch_dims), convert(gather_dim))] diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index af605b486bdf..ed0c66fe5c3f 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1075,7 +1075,7 @@ def gather(data, axis, indices): return _make.gather(data, axis, indices) -def gather_nd(data, indices, batch_dims=0): +def gather_nd(data, indices, batch_dims=0, gather_dim=-1): """Gather elements or slices from data and store to a tensor whose shape is defined by indices. @@ -1090,6 +1090,10 @@ def gather_nd(data, indices, batch_dims=0): batch_dims : int The number of batch dimensions. + gather_dim : int + The size of an indexing tuple, which is a fixed value and the same as indices.shape[0] + Only needed when other dimensions of indices are dynamic. 
+ Returns ------- ret : relay.Expr @@ -1111,7 +1115,7 @@ def gather_nd(data, indices, batch_dims=0): indices = [[1, 0]] relay.gather_nd(data, indices, batch_dims=1) = [[2,3],[4,5]] """ - return _make.gather_nd(data, indices, batch_dims) + return _make.gather_nd(data, indices, batch_dims, gather_dim) def sequence_mask(data, valid_length, mask_value=0, axis=0): diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index d7b008c4c33f..d11c835cfe99 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -200,6 +200,8 @@ def scatter(data, indices, updates, axis=0): def _verify_scatter_nd_inputs(data, indices, updates): + # TODO(masahi): revisit + return mdim = int(indices.shape[0]) assert mdim <= len(data.shape), ( f"The first dimension of the indices ({mdim}) must be less than or equal to " diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 8c012ecb47d3..98347b8e2cb9 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -3600,10 +3600,11 @@ Array GatherNDCompute(const Attrs& attrs, const Array& i return {topi::gather_nd(inputs[0], inputs[1], param->batch_dims)}; } -Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0) { +Expr MakeGatherND(Expr data, Expr indices, int batch_dims = 0, int gather_dim = -1) { static const Op& op = Op::Get("gather_nd"); auto attrs = make_object(); attrs->batch_dims = batch_dims; + attrs->gather_dim = gather_dim; return Call(op, {data, indices}, Attrs(attrs)); }