Skip to content

Commit

Permalink
SparseReshape Op (apache#7477)
Browse files Browse the repository at this point in the history
* SparseReshape Inital Code

* Done

* Format

* Add empty tests

* Formatting

* SanityCheck

* formatting documentation

* Documentation

* Only Enable CPU

* Add support for CUDA

* Stuff

* Add Dynamic Support

* Parallelize GPU Impl

* Documentation

* Documentation

* Import

* Import

* Remove unnecessary code

* PR Comments

* Schedules

* Tests

* Dtypes

* Black

* Parallelize CPU

* CI error

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
2 people authored and Lokiiiiii committed Mar 1, 2021
1 parent 724b8be commit 7770c1b
Show file tree
Hide file tree
Showing 12 changed files with 916 additions and 0 deletions.
10 changes: 10 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,6 +1157,15 @@ def _impl(inputs, attr, params, mod):
return _impl


def _sparse_reshape():
def _impl(inputs, attr, params, mod):
assert len(inputs) == 3, "There should be 3 input tensors"
new_indices, new_shape = get_relay_op("sparse_reshape")(inputs[0], inputs[1], inputs[2])
return _expr.TupleWrapper(_expr.Tuple([new_indices, new_shape]), 2)

return _impl


def _identity():
def _impl(inputs, attr, params, mod):
return inputs[0]
Expand Down Expand Up @@ -2650,6 +2659,7 @@ def _impl(inputs, attr, params, mod):
"SparseToDense": _sparse_to_dense(),
"SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(),
"SparseFillEmptyRows": _sparse_fill_empty_rows(),
"SparseReshape": _sparse_reshape(),
"Split": _split(False),
"SplitV": _split(True),
"Sqrt": AttrCvt("sqrt"),
Expand Down
35 changes: 35 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
_reg.register_injective_schedule("matrix_set_diag")
_reg.register_injective_schedule("adv_index")


# concatenate
_reg.register_schedule("concatenate", strategy.schedule_concatenate)

Expand Down Expand Up @@ -114,6 +115,22 @@ def compute_sparse_fill_empty_rows(attrs, inputs, output_type):

_reg.register_strategy("sparse_fill_empty_rows", strategy.sparse_fill_empty_rows_strategy)

# sparse_reshape
@_reg.register_compute("sparse_reshape")
def compute_reshape(attrs, inputs, output_type):
"""Compute definition of sparse_reshape"""

return topi.sparse_reshape(
inputs[0],
inputs[1],
inputs[2],
output_type.fields[0].shape,
output_type.fields[1].shape,
)


_reg.register_strategy("sparse_reshape", strategy.sparse_reshape_strategy)

# scatter_add
@_reg.register_compute("scatter_add")
def compute_scatter_add(attrs, inputs, output_type):
Expand Down Expand Up @@ -526,6 +543,24 @@ def sparse_fill_empty_rows_func(attrs, inputs, _):
return _sparse_fill_empty_rows_shape_func(inputs[0], inputs[2])


@script
def _sparse_reshape_shape_func(sparse_indices_shape, prev_shape_shape, new_shape_shape):
indices_shape = output_tensor((2,), "int64")
indices_shape[0] = int64(sparse_indices_shape[0])
indices_shape[1] = int64(new_shape_shape[0])
shape_tensor = output_tensor((1,), "int64")
shape_tensor[0] = int64(new_shape_shape[0])
return (indices_shape, shape_tensor)


@_reg.register_shape_func("sparse_reshape", False)
def sparse_reshape_shape_func(attrs, inputs, _):
"""
Shape func for sparse_reshape.
"""
return _sparse_reshape_shape_func(inputs[0], inputs[1], inputs[2])


@script
def _layout_transform_shape_func(
data_shape, out_layout_len, dst_equal_list, dst_mul_list, dst_div_list, dst_mix_list
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,17 @@ def sparse_dense_strategy_cuda(attrs, inputs, out_type, target):
return strategy


@sparse_reshape_strategy.register(["cuda", "gpu"])
def sparse_reshape_strategy_cuda(attrs, inputs, out_type, target):
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_sparse_reshape(topi.cuda.sparse_reshape),
wrap_topi_schedule(topi.generic.schedule_extern),
name="sparse_reshape.cuda",
)
return strategy


@sparse_dense_padded_strategy.register(["cuda", "gpu"])
def sparse_dense_padded_strategy_cuda(attrs, inputs, out_type, target):
"""sparse dense cuda strategy"""
Expand Down
27 changes: 27 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,33 @@ def _compute_sparse_fill_empty_rows(attrs, inputs, output_type):
return _compute_sparse_fill_empty_rows


# sparse_reshape
@override_native_generic_func("sparse_reshape_strategy")
def sparse_reshape_strategy(attrs, outs, out_type, target):
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_sparse_reshape(topi.sparse_reshape),
wrap_topi_schedule(topi.generic.schedule_extern),
name="sparse_reshape.generic",
)
return strategy


def wrap_compute_sparse_reshape(topi_compute):
"""Wrap sparse_reshape compute"""

def _compute_sparse_reshape(attrs, inputs, output_type):
return topi_compute(
inputs[0],
inputs[1],
inputs[2],
output_type.fields[0].shape,
output_type.fields[1].shape,
)

return _compute_sparse_reshape


# roi_pool
@generic_func
def schedule_roi_pool(attrs, outs, target):
Expand Down
40 changes: 40 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1449,6 +1449,46 @@ def sparse_fill_empty_rows(sparse_indices, sparse_values, dense_shape, default_v
return Tuple((new_sparse_indices, new_sparse_values, empty_row_indicator))


def sparse_reshape(sparse_indices, prev_shape, new_shape):
"""
Reshape a Sparse Tensor. The sparse array is in COO format.
Parameters
----------
sparse_indices : relay.Expr
A 2-D tensor[N, n_dim] of integers containing location of sparse values, where N is the
number of sparse values and n_dim is the number of dimensions of the dense_shape
prev_shape : relay.Expr
A 1-D tensor containing the previous shape of the dense tensor
new_shape : relay.Expr
A 1-D tensor containing the new shape of the dense tensor
Returns
-------
result: relay.Expr
Output tensor.
Examples
--------
.. code-block:: python
sparse_indices = [[0, 0, 0],
[0, 0, 1],
[0, 1, 0],
[1, 0, 0],
[1, 2, 3]]
prev_shape = [2, 3, 4]
new_shape = [9, -1]
new_sparse_indices, new_shape = relay.sparse_reshape(sparse_indices,
prev_shape,
new_shape)
new_sparse_indices = [[0, 0],
[0, 1],
[1, 2],
[4, 2],
[8, 1]]
new_shape = [9, 4]
"""
return TupleWrapper(_make.sparse_reshape(sparse_indices, prev_shape, new_shape), 2)


def cumsum(data, axis=None, dtype=None, exclusive=None):
"""Numpy style cumsum op. Return the cumulative inclusive sum of the elements along
a given axis.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .sort import *
from .scatter import *
from .sparse_fill_empty_rows import *
from .sparse_reshape import *
from .scatter_add import *
from .argwhere import *
from .interpolate import *
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,5 @@
from . import tensorcore_alter_op
from .argwhere import *
from .scan import *
from .sparse_reshape import *
from .unique import *
Loading

0 comments on commit 7770c1b

Please sign in to comment.