diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index b34e6c723645..8529205d96d4 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -998,6 +998,33 @@ def _impl(inputs, attr, params, mod):
     return _impl
 
 
+def _sparse_fill_empty_rows():
+    """Converter for the TensorFlow SparseFillEmptyRows op."""
+
+    def _impl(inputs, attr, params, mod):
+        assert len(inputs) == 4, "There should be 4 input tensors"
+        sparse_indices = inputs[0]
+        sparse_values = inputs[1]
+        # Sort the entries by row id (first column of the indices) before calling the
+        # relay op, carrying the values along in the same order.
+        sparse_indices_num_cols = _infer_shape(sparse_indices, mod)[1]
+        first_column = _op.split(sparse_indices, sparse_indices_num_cols, axis=1)[0]
+        # Squeeze only the column axis so that a single-entry input (N == 1) stays 1-D.
+        sorted_indices = _op.argsort(_op.squeeze(first_column, axis=[1]))
+        sorted_sparse_indices = _op.take(sparse_indices, sorted_indices, axis=0)
+        sorted_sparse_values = _op.take(sparse_values, sorted_indices, axis=0)
+        new_sparse_indices, new_sparse_values, empty_row_indicator = _op.sparse_fill_empty_rows(
+            sorted_sparse_indices, sorted_sparse_values, inputs[2], inputs[3]
+        )
+
+        return _expr.TupleWrapper(
+            _expr.Tuple([new_sparse_indices, new_sparse_values, empty_row_indicator]),
+            3,
+        )
+
+    return _impl
+
+
 def _identity():
     def _impl(inputs, attr, params, mod):
         return inputs[0]
@@ -2447,6 +2474,7 @@ def _impl(inputs, attr, params, mod):
     "SpaceToDepth": _space_to_depth(),
     "SparseToDense": _sparse_to_dense(),
     "SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(),
+    "SparseFillEmptyRows": _sparse_fill_empty_rows(),
     "Split": _split(False),
     "SplitV": _split(True),
     "Sqrt": AttrCvt("sqrt"),
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index ba2416ff8950..01bcf4a6cf60 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -15,7 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 """Backend compiler related feature registration"""
-# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks, too-many-local-variables, too-many-arguments
+# pylint: disable=invalid-name,unused-argument, len-as-condition, too-many-nested-blocks,
+# pylint: disable=too-many-local-variables, too-many-arguments, no-else-return
+
 from __future__ import absolute_import
 import tvm
 from tvm import te
@@ -94,6 +96,24 @@ def compute_scatter(attrs, inputs, output_type):
 
 _reg.register_strategy("scatter", strategy.scatter_strategy)
 
+# sparse_fill_empty_rows
+@_reg.register_compute("sparse_fill_empty_rows")
+def compute_sparse_fill_empty_rows(attrs, inputs, output_type):
+    """Compute definition of sparse_fill_empty_rows"""
+
+    return topi.sparse_fill_empty_rows(
+        inputs[0],
+        inputs[1],
+        inputs[2],
+        inputs[3],
+        output_type.fields[0].shape,
+        output_type.fields[1].shape,
+        output_type.fields[2].shape,
+    )
+
+
+_reg.register_strategy("sparse_fill_empty_rows", strategy.sparse_fill_empty_rows_strategy)
+
 # scatter_add
 @_reg.register_compute("scatter_add")
 def compute_scatter_add(attrs, inputs, output_type):
@@ -445,6 +465,50 @@ def argwhere_shape_func(attrs, inputs, out_ndims):
 _reg.register_shape_func("scatter_add", False, elemwise_shape_func)
 
 
+# Data-dependent shape function: the number of output rows equals the number of input
+# entries plus one filler row for every row id in [0, dense_shape[0]) that is missing from
+# the (sorted) first column of sparse_indices.
+@script
+def _sparse_fill_empty_rows_shape_func(sparse_indices, dense_shape):
+
+    new_sparse_indices_shape = output_tensor((2,), "int64")
+    new_sparse_values_shape = output_tensor((1,), "int64")
+    empty_row_indicator_shape = output_tensor((1,), "int64")
+    num_dense_rows = int64(dense_shape[0])
+
+    if int64(sparse_indices.shape[0]) == int64(0):  # Handle Empty Case
+        # Total rows will equal dense_shape[0]
+        new_sparse_indices_shape[0] = num_dense_rows
+        new_sparse_indices_shape[1] = int64(sparse_indices.shape[1])
+        new_sparse_values_shape[0] = num_dense_rows
+        empty_row_indicator_shape[0] = num_dense_rows
+        return (new_sparse_indices_shape, new_sparse_values_shape, empty_row_indicator_shape)
+
+    else:
+        count = int64(sparse_indices.shape[0])  # Add count of all rows already in sparse_indices
+        for i in range(1, int64(sparse_indices.shape[0])):
+            index = int64(sparse_indices[i, 0])
+            prev_index = int64(sparse_indices[i - 1, 0] + 1)
+
+            if index > prev_index:
+                count += index - prev_index  # Add count of all rows between two consecutive indices
+
+        count += int64(sparse_indices[0, 0])  # Add count from 0 to first row id in sparse_indices
+        count += int64(
+            num_dense_rows - 1 - sparse_indices[sparse_indices.shape[0] - 1, 0]
+        )  # Add count from last row id to dense_shape - 1
+        new_sparse_indices_shape[0] = int64(count)
+        new_sparse_indices_shape[1] = int64(sparse_indices.shape[1])
+        new_sparse_values_shape[0] = int64(count)
+        empty_row_indicator_shape[0] = num_dense_rows
+        return (new_sparse_indices_shape, new_sparse_values_shape, empty_row_indicator_shape)
+
+
+@_reg.register_shape_func("sparse_fill_empty_rows", True)
+def sparse_fill_empty_rows_func(attrs, inputs, _):
+    return _sparse_fill_empty_rows_shape_func(inputs[0], inputs[2])
+
+
 @script
 def _layout_transform_shape_func(
     data_shape, out_layout_len, dst_equal_list, dst_mul_list, dst_div_list, dst_mix_list
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index af1d2552fab7..d4b1d15a5318 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1055,6 +1055,36 @@ def roi_align_strategy(attrs, inputs, out_type, target):
     return strategy
 
 
+# sparse_fill_empty_rows
+@override_native_generic_func("sparse_fill_empty_rows_strategy")
+def sparse_fill_empty_rows_strategy(attrs, outs, out_type, target):
+    """sparse_fill_empty_rows generic strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_sparse_fill_empty_rows(topi.sparse_fill_empty_rows),
+        wrap_topi_schedule(topi.generic.schedule_sparse_fill_empty_rows),
+        name="sparse_fill_empty_rows.generic",
+    )
+    return strategy
+
+
+def wrap_compute_sparse_fill_empty_rows(topi_compute):
+    """Wrap sparse_fill_empty_rows compute"""
+
+    def _compute_sparse_fill_empty_rows(attrs, inputs, output_type):
+        return topi_compute(
+            inputs[0],
+            inputs[1],
+            inputs[2],
+            inputs[3],
+            output_type.fields[0].shape,
+            output_type.fields[1].shape,
+            output_type.fields[2].shape,
+        )
+
+    return _compute_sparse_fill_empty_rows
+
+
 # roi_pool
 @generic_func
 def schedule_roi_pool(attrs, outs, target):
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index e9d081eb5fb6..f8eed3516f91 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -1322,6 +1322,81 @@ def adv_index(inputs):
     return _make.adv_index(Tuple(inputs))
 
 
+def sparse_fill_empty_rows(sparse_indices, sparse_values, dense_shape, default_value):
+    """
+    Fill rows in a sparse matrix that do not contain any values. Values are placed in the first
+    column of empty rows. The sparse array is in COO format.
+    It returns a Tuple with 3 outputs.
+
+    Parameters
+    ----------
+    sparse_indices : relay.Expr
+        A 2-D tensor[N, ndims] of integers containing location of sparse values, where N is
+        the number of sparse values and ndims is the number of dimensions of the dense_shape.
+        The first column of this relay parameter must be sorted in ascending order.
+    sparse_values : relay.Expr
+        A 1-D tensor[N] containing the sparse values for the sparse indices.
+    dense_shape : relay.Expr
+        A 1-D tensor[ndims] which contains shape of the dense output tensor.
+    default_value : relay.Expr
+        A 1-D tensor[1] containing the default value for the remaining locations.
+
+    Returns
+    -------
+    new_sparse_indices : relay.Expr
+        A 2-D tensor[?, ndims] of integers containing location of new sparse
+        indices. The first column outputs must be sorted in ascending order.
+    new_sparse_values : relay.Expr
+        A 1-D tensor[?] containing the sparse values for the sparse indices.
+    empty_row_indicator : relay.Expr
+        A 1-D boolean tensor[dense_shape[0]] that is True for every row that was empty
+        in the input (and therefore got filled) and False for every other row.
+
+    Note
+    ----
+    This op exactly follows the documentation here:
+    https://www.tensorflow.org/api_docs/python/tf/sparse/fill_empty_rows
+    with one exception: input sparse indices are expected to be in row-major order.
+    The underlying op produces the empty row indicator as int64 (1 for True, 0 for False);
+    this function casts it to bool before returning it.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        sparse_indices = [[0, 1],
+                         [0, 3],
+                         [2, 0],
+                         [3, 1]]
+        sparse_values = [1, 2, 3, 4]
+        default_value = [10]
+        dense_shape = [5, 6]
+        new_sparse_indices, new_sparse_values, empty_row_indicator =
+                            relay.sparse_fill_empty_rows(
+                            sparse_indices,
+                            sparse_values,
+                            dense_shape,
+                            default_value)
+        new_sparse_indices = [[0, 1],
+                           [0, 3],
+                           [1, 0],
+                           [2, 0],
+                           [3, 1],
+                           [4, 0]]
+        new_sparse_values = [1, 2, 10, 3, 4, 10]
+        empty_row_indicator = [False, True, False, False, True]
+
+    """
+    new_sparse_indices, new_sparse_values, empty_row_indicator = TupleWrapper(
+        _make.sparse_fill_empty_rows(sparse_indices, sparse_values, dense_shape, default_value), 3
+    )
+    new_sparse_indices = cast_like(new_sparse_indices, sparse_indices)
+    new_sparse_values = cast_like(new_sparse_values, sparse_values)
+    empty_row_indicator = cast(empty_row_indicator, "bool")
+
+    return Tuple((new_sparse_indices, new_sparse_values, empty_row_indicator))
+
+
 def cumsum(data, axis=None, dtype=None, exclusive=None):
     """Numpy style cumsum op. Return the cumulative inclusive sum of the elements along
     a given axis.
diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py
index 6836f04b5ada..2b17162048e0 100644
--- a/python/tvm/topi/__init__.py
+++ b/python/tvm/topi/__init__.py
@@ -38,6 +38,7 @@
 from .broadcast import *
 from .sort import *
 from .scatter import *
+from .sparse_fill_empty_rows import *
 from .scatter_add import *
 from .argwhere import *
 from .cumsum import *
diff --git a/python/tvm/topi/generic/search.py b/python/tvm/topi/generic/search.py
index b3c8772046fd..5924d35def73 100644
--- a/python/tvm/topi/generic/search.py
+++ b/python/tvm/topi/generic/search.py
@@ -66,3 +66,8 @@ def schedule_scatter_add(outs):
         The computation schedule for the op.
     """
     return _default_schedule(outs, False)
+
+
+def schedule_sparse_fill_empty_rows(outs):
+    """Schedule for the sparse_fill_empty_rows operator (uses the default schedule)."""
+    return _default_schedule(outs, False)
diff --git a/python/tvm/topi/sparse_fill_empty_rows.py b/python/tvm/topi/sparse_fill_empty_rows.py
new file mode 100644
index 000000000000..10dc6ee3bfa3
--- /dev/null
+++ b/python/tvm/topi/sparse_fill_empty_rows.py
@@ -0,0 +1,109 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=no-else-return, too-many-locals, too-many-arguments, too-many-branches
+# pylint: disable=undefined-variable, invalid-name
+"""SparseFillEmptyRows operator"""
+from ..te import hybrid
+
+
+@hybrid.script
+def _sparse_fill_empty_rows(
+    sparse_indices,
+    sparse_values,
+    dense_shape,
+    default_value,
+    new_sparse_indices_shape,
+    new_sparse_values_shape,
+    empty_row_indicator_shape,
+):
+    default_value_ = int64(default_value[0])
+    new_sparse_indices = output_tensor(new_sparse_indices_shape, "int64")
+    new_sparse_values = output_tensor(new_sparse_values_shape, "int64")
+    empty_row_indicator = output_tensor(empty_row_indicator_shape, "int64")
+    new_sparse_indices_row_id = 0
+
+    if int64(sparse_indices.shape[0]) == int64(0):  # Handle Empty Case
+        # Fill all rows with default values
+        for i in range(0, new_sparse_indices_shape[0]):
+            new_sparse_indices[i, 0] = int64(i)
+            new_sparse_values[i] = default_value_
+            empty_row_indicator[i] = int64(1)
+            for k in range(1, int64(new_sparse_indices_shape[1])):
+                new_sparse_indices[i, k] = int64(0)
+
+        return (new_sparse_indices, new_sparse_values, empty_row_indicator)
+
+    else:
+        # Iterate through sparse_indices and add rows if/when required
+        for i in range(0, int64(sparse_indices.shape[0])):
+            if i == 0:
+                prev_row_id = int64(0)
+            else:
+                prev_row_id = int64(sparse_indices[i - 1, 0] + 1)
+            row_id = int64(sparse_indices[i, 0])
+
+            # Since input is in row-major order, add rows between prev_row_id and row_id
+            for j in range(prev_row_id, row_id):
+                new_sparse_indices[new_sparse_indices_row_id, 0] = int64(j)
+                for k in range(1, int64(new_sparse_indices_shape[1])):
+                    new_sparse_indices[new_sparse_indices_row_id, k] = int64(0)
+                empty_row_indicator[j] = int64(1)  # mark each filled row, not only the first
+                new_sparse_values[new_sparse_indices_row_id] = default_value_
+                new_sparse_indices_row_id += 1
+
+            # Add current element to output
+            new_sparse_indices[new_sparse_indices_row_id, 0] = row_id
+            for k in range(1, int64(new_sparse_indices_shape[1])):
+                new_sparse_indices[new_sparse_indices_row_id, k] = int64(sparse_indices[i, k])
+            new_sparse_values[new_sparse_indices_row_id] = int64(sparse_values[i])
+            empty_row_indicator[row_id] = int64(0)
+            new_sparse_indices_row_id += 1
+
+        # Add rows with default value if last row id of sparse_indices is not dense_shape[0] - 1
+        for i in range(
+            int64(sparse_indices[sparse_indices.shape[0] - 1, 0] + 1), int64(dense_shape[0])
+        ):
+
+            new_sparse_indices[new_sparse_indices_row_id, 0] = int64(i)
+            for k in range(1, int64(new_sparse_indices_shape[1])):
+                new_sparse_indices[new_sparse_indices_row_id, k] = int64(0)
+            empty_row_indicator[i] = int64(1)
+            new_sparse_values[new_sparse_indices_row_id] = default_value_
+            new_sparse_indices_row_id += 1
+
+        return (new_sparse_indices, new_sparse_values, empty_row_indicator)
+
+
+def sparse_fill_empty_rows(
+    sparse_indices,
+    sparse_values,
+    dense_shape,
+    default_value,
+    new_sparse_indices_shape,
+    new_sparse_values_shape,
+    empty_row_indicator_shape,
+):
+    return _sparse_fill_empty_rows(
+        sparse_indices,
+        sparse_values,
+        dense_shape,
+        default_value,
+        new_sparse_indices_shape,
+        new_sparse_values_shape,
+        empty_row_indicator_shape,
+    )
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 5e39b409615d..35954746ab1f 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1584,6 +1584,54 @@ RELAY_REGISTER_OP("repeat")
     .set_attr<FTVMCompute>("FTVMCompute", RepeatCompute)
     .set_attr<TOpPattern>("TOpPattern", kBroadcast);
 
+bool SparseFillEmptyRowsRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                            const TypeReporter& reporter) {
+  // types: [sparse_indices, sparse_values, dense_shape, default_value, result]
+  ICHECK_EQ(types.size(), 5) << "SparseFillEmptyRowsRel expects 5 inputs but " << types.size()
+                             << " provided";
+  std::vector<Type> fields;
+  auto sparse_indices = types[0].as<TensorTypeNode>();
+  // The type of sparse_indices may not be resolved yet; retry the relation later.
+  if (sparse_indices == nullptr) {
+    return false;
+  }
+  auto ndims = sparse_indices->shape[1];
+  fields.push_back(TensorType(Array<PrimExpr>{Any(), ndims}, tvm::DataType::Int(64)));
+  fields.push_back(TensorType(Array<PrimExpr>{Any()}, tvm::DataType::Int(64)));
+  fields.push_back(TensorType(Array<PrimExpr>{Any()}, tvm::DataType::Int(64)));
+  reporter->Assign(types[types.size() - 1], TupleType(Array<Type>(fields)));
+  return true;
+}
+
+Expr MakeSparseFillEmptyRows(Expr sparse_indices, Expr sparse_values, Expr dense_shape,
+                             Expr default_value) {
+  static const Op& op = Op::Get("sparse_fill_empty_rows");
+  return Call(op, {sparse_indices, sparse_values, dense_shape, default_value}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.sparse_fill_empty_rows")
+    .set_body_typed(MakeSparseFillEmptyRows);
+
+RELAY_REGISTER_OP("sparse_fill_empty_rows")
+    .describe(
+        R"code(Fill empty rows of a sparse tensor with a default value.)code" TVM_ADD_FILELINE)
+    .set_num_inputs(4)
+    .add_argument("sparse_indices", "Tensor",
+                  "A 2-D int64 tensor of shape [N, ndims], which specifies the indices of the "
+                  "elements in the sparse tensor that contain nonzero values. COO Format")
+    .add_argument(
+        "sparse_values", "Tensor",
+        "A 1-D tensor[N] which supplies the values for each element in indices. COO Format")
+    .add_argument("dense_shape", "Tensor",
+                  "A 1-D int64 tensor of shape [ndims], which specifies the dense_shape of the "
+                  "sparse tensor. Takes a list indicating the number of elements in each "
+                  "dimension")
+    .add_argument("default_value", "Tensor",
+                  "The value to fill for empty rows, with the same type as sparse_values")
+    .add_type_rel("sparse_fill_empty_rows", SparseFillEmptyRowsRel)
+    .set_support_level(3)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque);
+
 // meshgrid operator
 TVM_REGISTER_NODE_TYPE(MeshgridAttrs);
 
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 34ee0f3528ae..7ecdb7882502 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1812,6 +1812,109 @@ def test_forward_sparse_dense_matmul():
     )
 
 
+#######################################################################
+# SparseFillEmptyRows
+# ------------
+
+
+def _test_sparse_fill_empty_rows(indices_np, values_np, dense_shape_np, default_value_int, use_dyn):
+    with tf.Graph().as_default():
+        if use_dyn:
+            indices = tf.placeholder(shape=(None, None), dtype=indices_np.dtype, name="indices")
+            values = tf.placeholder(shape=(None), dtype=values_np.dtype, name="values")
+            dense_shape = tf.placeholder(
+                shape=(None), dtype=dense_shape_np.dtype, name="dense_shape"
+            )
+        else:
+            indices = tf.placeholder(shape=indices_np.shape, dtype=indices_np.dtype, name="indices")
+            values = tf.placeholder(shape=values_np.shape, dtype=values_np.dtype, name="values")
+            dense_shape = tf.placeholder(
+                shape=dense_shape_np.shape, dtype=dense_shape_np.dtype, name="dense_shape"
+            )
+
+        default_value = tf.placeholder(shape=(), dtype=values_np.dtype, name="default_value")
+        sp_input = tf.sparse.SparseTensor(indices=indices, values=values, dense_shape=dense_shape)
+        _ = tf.sparse.fill_empty_rows(sp_input, default_value, name="sparse_fill_empty_rows")
+        compare_tf_with_tvm(
+            [indices_np, values_np, dense_shape_np, default_value_int],
+            [indices.name, values.name, dense_shape.name, default_value.name],
+            [
+                "sparse_fill_empty_rows/SparseFillEmptyRows:0",
+                "sparse_fill_empty_rows/SparseFillEmptyRows:1",
+                "sparse_fill_empty_rows/SparseFillEmptyRows:2",
+            ],
+            mode="vm",
+        )
+
+
+@pytest.mark.parametrize(
+    "sparse_indices_np, sparse_values_np, dense_shape_np, default_value_int",
+    [
+        (
+            np.array([[1, 1], [0, 3], [0, 1], [2, 0], [3, 1]], dtype=np.int64),
+            np.array([1, 2, 3, 4, 5], dtype=np.int64),
+            np.array([5, 6], dtype=np.int64),
+            10,
+        ),
+        (
+            np.array([[1, 1], [0, 3], [2, 0], [3, 1]], dtype=np.int64),
+            np.array([1, 2, 3, 4], dtype=np.int64),
+            np.array([5, 6], dtype=np.int64),
+            10,
+        ),
+        (
+            np.array([[0, 1], [0, 3], [2, 0], [3, 1]], dtype=np.int64),
+            np.array([1, 2, 3, 4], dtype=np.int64),
+            np.array([5, 6], dtype=np.int64),
+            10,
+        ),
+        (
+            np.array([[1, 1, 1], [1, 3, 1], [2, 0, 5], [3, 1, 6]], dtype=np.int64),
+            np.array([1, 2, 3, 4], dtype=np.int64),
+            np.array([7, 7, 7], dtype=np.int64),
+            5,
+        ),
+        (
+            np.array([[1], [2]], dtype=np.int64),
+            np.array([7, 8], dtype=np.int64),
+            np.array([5], dtype=np.int64),
+            4,
+        ),
+        (
+            np.ones((0, 1), dtype=np.int64),
+            np.array([], dtype=np.int64),
+            np.array([5], dtype=np.int64),
+            4,
+        ),
+        (
+            np.ones((0, 3), dtype=np.int64),
+            np.array([], dtype=np.int64),
+            np.array([9, 3, 7], dtype=np.int64),
+            100,
+        ),
+    ],
+)
+@pytest.mark.parametrize("use_dyn", [True, False])
+def test_forward_sparse_fill_empty_rows(
+    sparse_indices_np, sparse_values_np, dense_shape_np, default_value_int, use_dyn
+):
+    """sparse_fill_empty_rows op test"""
+    ###################################################################
+    #
+    # In order to create a SparseTensor, it requires 3 inputs as below:
+    #    SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
+    #
+    # Above Sparse can be represented in Dense as below :
+    #    [[1, 0, 0, 0]
+    #     [0, 0, 2, 0]
+    #     [0, 0, 0, 0]]
+    #
+    # ------------------------------------------------------------------
+    _test_sparse_fill_empty_rows(
+        sparse_indices_np, sparse_values_np, dense_shape_np, default_value_int, use_dyn
+    )
+
+
 #######################################################################
 # StridedSlice
 # ------------
diff --git a/tests/python/relay/dyn/test_dynamic_op_level3.py b/tests/python/relay/dyn/test_dynamic_op_level3.py
index dd73b9a96a52..d5f81e84e39d 100644
--- a/tests/python/relay/dyn/test_dynamic_op_level3.py
+++ b/tests/python/relay/dyn/test_dynamic_op_level3.py
@@ -26,14 +26,21 @@
 import tvm.testing
 
 
-def verify_func(func, data, ref_res):
+def verify_func(func, data, ref_res, target_ctx=tvm.testing.enabled_targets()):
     assert isinstance(data, list)
-    for target, ctx in tvm.testing.enabled_targets():
+    for target, ctx in target_ctx:
         for kind in ["vm", "debug"]:
             mod = tvm.ir.IRModule.from_expr(func)
             intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
             op_res = intrp.evaluate()(*data)
-            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+            if isinstance(op_res, tvm.runtime.container.ADT):
+                assert len(op_res) == len(
+                    ref_res
+                ), "Outputs from TVM and Python implementation must be equal "
+                for op_result, ref_result in zip(op_res, ref_res):
+                    tvm.testing.assert_allclose(op_result.asnumpy(), ref_result, rtol=1e-5)
+            else:
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
     relay.backend.compile_engine.get().clear()
 
 
@@ -202,5 +209,160 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_
     verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0])  # default value not specified
 
 
+@pytest.mark.parametrize(
+    "sparse_indices, sparse_values, dense_shape, default_value",
+    [
+        (
+            np.array([[0, 1], [0, 3], [2, 0], [3, 1]], dtype=np.int64),
+            np.array([1, 2, 3, 4], dtype=np.int64),
+            np.array([5, 6], dtype=np.int64),
+            np.array([10], dtype=np.int64),
+        ),
+        (
+            np.array([[1, 1, 1], [1, 3, 1], [2, 0, 5], [3, 1, 6]], dtype=np.int64),
+            np.array([1, 2, 3, 4], dtype=np.int64),
+            np.array([7, 7, 7], dtype=np.int64),
+            np.array([5], dtype=np.int64),
+        ),
+        (
+            np.array([[1], [2]], dtype=np.int64),
+            np.array([7, 8], dtype=np.int64),
+            np.array([5], dtype=np.int64),
+            np.array([4], dtype=np.int64),
+        ),
+        (
+            np.ones((0, 1), dtype=np.int64),
+            np.array([], dtype=np.int64),
+            np.array([5], dtype=np.int64),
+            np.array([4], dtype=np.int64),
+        ),
+        (
+            np.ones((0, 3), dtype=np.int64),
+            np.array([], dtype=np.int64),
+            np.array([9, 3, 7], dtype=np.int64),
+            np.array([100], dtype=np.int64),
+        ),
+    ],
+)
+@pytest.mark.parametrize("dtype", [np.int64, np.int32])
+@pytest.mark.parametrize("use_dyn", [True, False])
+def test_sparse_fill_empty_rows(
+    sparse_indices, sparse_values, dense_shape, default_value, dtype, use_dyn
+):
+    def ref_sparse_fill_empty_rows(
+        sparse_indices: np.ndarray,
+        sparse_values: np.ndarray,
+        dense_shape: np.ndarray,
+        default_value: np.ndarray,
+    ) -> tuple:
+        """
+        This function calculates the expected output of sparse_fill_empty_rows operator given the
+        inputs.
+        """
+
+        def check_add_rows(current_idx, limit_idx):
+            while current_idx < limit_idx:
+                new_sparse_indices.append([current_idx] + [0] * (num_cols - 1))
+                new_sparse_values.append(default_value[0])
+                empty_row_indicator[current_idx] = True
+                current_idx += 1
+
+            return current_idx
+
+        current_idx = 0
+        new_sparse_indices = []
+        new_sparse_values = []
+        empty_row_indicator = [False for _ in range(dense_shape[0])]
+        num_cols = sparse_indices.shape[1]
+        for sparse_row, sparse_value in zip(sparse_indices, sparse_values):
+            limit_idx = sparse_row[0]
+            current_idx = check_add_rows(current_idx, limit_idx)
+            new_sparse_indices.append(list(sparse_row))
+            new_sparse_values.append(sparse_value)
+            current_idx = limit_idx + 1
+
+        check_add_rows(current_idx, dense_shape[0])
+        return new_sparse_indices, new_sparse_values, empty_row_indicator
+
+    def verify_sparse_fill_empty_rows(
+        sparse_indices_np: np.ndarray,
+        sparse_values_np: np.ndarray,
+        dense_shape_np: np.ndarray,
+        default_value_np: np.ndarray,
+    ) -> None:
+        """
+        This function verifies the relay output of sparse_fill_empty_rows with its expected output.
+        """
+        if use_dyn:
+            sparse_indices = relay.var(
+                "sparse_indices",
+                shape=[relay.Any(), relay.Any()],
+                dtype=str(sparse_indices_np.dtype),
+            )
+            sparse_values = relay.var(
+                "sparse_values",
+                shape=[relay.Any()],
+                dtype=str(sparse_values_np.dtype),
+            )
+            dense_shape = relay.var(
+                "dense_shape",
+                shape=[relay.Any()],
+                dtype=str(dense_shape_np.dtype),
+            )
+            default_value = relay.var(
+                "default_value",
+                shape=[relay.Any()],
+                dtype=str(default_value_np.dtype),
+            )
+        else:
+            sparse_indices = relay.var(
+                "sparse_indices",
+                relay.TensorType(sparse_indices_np.shape, str(sparse_indices_np.dtype)),
+            )
+            sparse_values = relay.var(
+                "sparse_values",
+                relay.TensorType(sparse_values_np.shape, str(sparse_values_np.dtype)),
+            )
+            dense_shape = relay.var(
+                "dense_shape",
+                relay.TensorType(dense_shape_np.shape, str(dense_shape_np.dtype)),
+            )
+            default_value = relay.var(
+                "default_value",
+                relay.TensorType(default_value_np.shape, str(default_value_np.dtype)),
+            )
+        z = relay.sparse_fill_empty_rows(sparse_indices, sparse_values, dense_shape, default_value)
+        func = relay.Function([sparse_indices, sparse_values, dense_shape, default_value], z)
+        ref_res = ref_sparse_fill_empty_rows(
+            sparse_indices_np,
+            sparse_values_np,
+            dense_shape_np,
+            default_value_np,
+        )
+        (
+            new_sparse_indices_infer_type,
+            new_sparse_values_infer_type,
+            empty_row_indicator_infer_type,
+        ) = run_infer_type(z)
+
+        assert new_sparse_indices_infer_type.checked_type.dtype == sparse_indices_np.dtype
+        assert new_sparse_values_infer_type.checked_type.dtype == sparse_values_np.dtype
+        assert empty_row_indicator_infer_type.checked_type.dtype == "bool"
+
+        verify_func(
+            func,
+            [sparse_indices_np, sparse_values_np, dense_shape_np, default_value_np],
+            ref_res,
+            [("llvm", tvm.cpu())],
+        )
+
+    verify_sparse_fill_empty_rows(
+        sparse_indices.astype(dtype),
+        sparse_values.astype(dtype),
+        dense_shape.astype(dtype),
+        default_value.astype(dtype),
+    )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])