Skip to content

Commit

Permalink
[Topi, Relay] Add cumprod (apache#7722)
Browse files Browse the repository at this point in the history
* make cumbinop, refactor cumsum, add cumprod

* cumsum exclusive test

* Add cumprod + flesh out cumsum tests

add cumprod and tests

reinstate tests

rethink

* add rudimentary scan implementation

* add attributes of cumprod node

* add cumprod strategy

* add cuda strategy

* python relay node construction

* change attrs to be reusuable

* add cumprod nodes

* complete tests

* Fix some typos about sum --> prod

typos fix sum -> prod

more typos

more typo fixes

more typos

add doc strings

* Use Bool instead of int to represent exclusive

make exclusive a bool up and down stack

fix x

fix bool err

it is a bool now

fix

fix thing

formatting to pass linter

lint python

cumprod pylint

fix attribute

fix ordering

add exclusivity tests for end to end

fix things

cuda identity_value

* Overall improve formatting, add doc message corrections

simplify construction

clang-format

more tests

undo simpler construction due to function passing stuff

fix docs

more exclusive doc changes

more fixins"

* merge cumsum and cumprod to scan, merge tests

fix stuff

* remove other mentions of cumbinop -> scanop

* lint formatting

Co-authored-by: Andrew Zhao Luo <[email protected]>
  • Loading branch information
2 people authored and trevor-m committed May 11, 2021
1 parent 5d544e4 commit 178d959
Show file tree
Hide file tree
Showing 14 changed files with 758 additions and 279 deletions.
14 changes: 8 additions & 6 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -446,17 +446,19 @@ struct MatrixSetDiagAttrs : public tvm::AttrsNode<MatrixSetDiagAttrs> {
}
}; // struct MatrixSetDiagAttrs

/*! \brief Attributes used in cumsum operator */
struct CumsumAttrs : public tvm::AttrsNode<CumsumAttrs> {
/*! \brief Attributes used in cumsum and cumprod operator */
struct ScanopAttrs : public tvm::AttrsNode<ScanopAttrs> {
Integer axis;
DataType dtype;
Integer exclusive;
TVM_DECLARE_ATTRS(CumsumAttrs, "relay.attrs.CumsumAttrs") {
TVM_ATTR_FIELD(axis).describe("The axis to sum over").set_default(NullValue<Integer>());
Bool exclusive = Bool(false);
TVM_DECLARE_ATTRS(ScanopAttrs, "relay.attrs.ScanopAttrs") {
TVM_ATTR_FIELD(axis).describe("The axis to operate over").set_default(NullValue<Integer>());
TVM_ATTR_FIELD(dtype).describe("Output data type").set_default(NullValue<DataType>());

// Default is 0 which is "false"
TVM_ATTR_FIELD(exclusive)
.describe("The first element is not included")
.set_default(NullValue<Integer>());
.set_default(Bool(false));
}
};

Expand Down
19 changes: 15 additions & 4 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@
# pylint: disable=too-many-local-variables, too-many-arguments, no-else-return

from __future__ import absolute_import

import tvm
from tvm import te
from tvm.te.hybrid import script
from tvm import te, topi
from tvm.runtime import convert
from tvm import topi
from tvm.te.hybrid import script
from tvm.topi.utils import get_const_int, get_const_tuple

from . import op as _reg
from . import strategy
from .op import OpPattern
from ._tensor import elemwise_shape_func
from .op import OpPattern

_reg.register_broadcast_schedule("broadcast_to")
_reg.register_broadcast_schedule("broadcast_to_like")
Expand Down Expand Up @@ -168,6 +169,16 @@ def compute_cumsum(attrs, inputs, output_type):
_reg.register_strategy("cumsum", strategy.cumsum_strategy)
_reg.register_shape_func("cumsum", False, elemwise_shape_func)

# cumprod
@_reg.register_compute("cumprod")
def compute_cumprod(attrs, inputs, output_type):
"""Compute definition of cumprod"""
return [topi.cumprod(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)]


_reg.register_strategy("cumprod", strategy.cumprod_strategy)
_reg.register_shape_func("cumprod", False, elemwise_shape_func)


@_reg.register_compute("unique")
def compute_unique(attrs, inputs, output_type):
Expand Down
19 changes: 16 additions & 3 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
from tvm import topi
from tvm.auto_scheduler import is_auto_scheduler_enabled
from tvm.te import SpecializedCondition
from tvm.contrib import nvcc
from tvm.contrib.thrust import can_use_thrust
from .generic import *
from tvm.te import SpecializedCondition

from .. import op as _op
from .generic import *


@schedule_injective.register(["cuda", "gpu"])
Expand Down Expand Up @@ -1017,13 +1018,25 @@ def cumsum_strategy_cuda(attrs, inputs, out_type, target):
"""cumsum cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_cumsum(topi.cuda.cumsum),
wrap_compute_scanop(topi.cuda.cumsum),
wrap_topi_schedule(topi.cuda.schedule_scan),
name="cumsum.cuda",
)
return strategy


@cumprod_strategy.register(["cuda", "gpu"])
def cumprod_strategy_cuda(attrs, inputs, out_type, target):
"""cumprod cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scanop(topi.cuda.cumprod),
wrap_topi_schedule(topi.cuda.schedule_scan),
name="cumprod.cuda",
)
return strategy


@unique_strategy.register(["cuda", "gpu"])
def unique_strategy_cuda(attrs, inputs, out_type, target):
"""unique cuda strategy"""
Expand Down
29 changes: 21 additions & 8 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
"""Definition of generic operator strategy."""
# pylint: disable=invalid-name,unused-argument
import logging

import re
from tvm import topi, _ffi, te, ir
from tvm.topi.utils import get_const_int, get_const_float, get_const_tuple, get_float_tuple

from tvm import _ffi, ir, te, topi
from tvm.target import generic_func, override_native_generic_func
from tvm.topi.utils import get_const_float, get_const_int, get_const_tuple, get_float_tuple

from .. import op as _op

logger = logging.getLogger("strategy")
Expand Down Expand Up @@ -1471,27 +1472,39 @@ def threefry_split_strategy(attrs, inputs, out_type, target):
return strategy


def wrap_compute_cumsum(topi_compute):
"""Wrap cumsum topi compute"""
def wrap_compute_scanop(topi_compute):
"""Wrap scanop style topi compute"""

def _compute_cumsum(attrs, inputs, _):
def _compute_scanop(attrs, inputs, _):
return [topi_compute(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)]

return _compute_cumsum
return _compute_scanop


@override_native_generic_func("cumsum_strategy")
def cumsum_strategy(attrs, inputs, out_type, target):
"""cumsum generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_cumsum(topi.cumsum),
wrap_compute_scanop(topi.cumsum),
wrap_topi_schedule(topi.generic.schedule_extern),
name="cumsum.generic",
)
return strategy


@override_native_generic_func("cumprod_strategy")
def cumprod_strategy(attrs, inputs, out_type, target):
"""cumprod generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scanop(topi.cumprod),
wrap_topi_schedule(topi.generic.schedule_extern),
name="cumprod.generic",
)
return strategy


def wrap_compute_unique(topi_compute):
"""Wrap unique topi compute"""

Expand Down
65 changes: 60 additions & 5 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
# pylint: disable=import-outside-toplevel, unused-argument, invalid-name
"""Transform operators."""

from ...tir import expr as _expr
from ..expr import Constant, Expr, Tuple, TupleWrapper, const
from . import _make
from .dyn import _make as _dyn_make
from .tensor import shape_of
from ..expr import TupleWrapper, const, Constant, Expr, Tuple
from ...tir import expr as _expr


def cast(data, dtype):
Expand Down Expand Up @@ -1578,9 +1578,9 @@ def cumsum(data, axis=None, dtype=None, exclusive=None):
Type of the returned array and of the accumulator in which the elements are summed.
If dtype is not specified, it defaults to the dtype of data.
exclusive : int, optional
If set to 1 will return exclusive sum in which the first element is not
included. In other terms, if set to 1, the j-th output element would be
exclusive : bool, optional
If true will return exclusive sum in which the first element is not
included. In other terms, if true, the j-th output element would be
the sum of the first (j-1) elements. Otherwise, it would be the sum of
the first j elements.
Expand Down Expand Up @@ -1616,6 +1616,61 @@ def cumsum(data, axis=None, dtype=None, exclusive=None):
return _make.cumsum(data, axis, dtype, exclusive)


def cumprod(data, axis=None, dtype=None, exclusive=None):
"""Numpy style cumprod op. Return the cumulative inclusive product of the elements along
a given axis.
Parameters
----------
data : relay.Expr
The input data to the operator.
axis : int, optional
Axis along which the cumulative product is computed. The default (None) is to compute
the cumprod over the flattened array.
dtype : string, optional
Type of the returned array and of the accumulator in which the elements are multiplied.
If dtype is not specified, it defaults to the dtype of data.
exclusive : bool, optional
If true will return exclusive product in which the first element is not
included. In other terms, if true, the j-th output element would be
the product of the first (j-1) elements. Otherwise, it would be the product of
the first j elements. The product of zero elements will be 1.
Returns
-------
result : relay.Expr
The result has the same size as data, and the same shape as data if axis is not None.
If axis is None, the result is a 1-d array.
Examples
--------
.. code-block:: python
a = [[1,2,3], [4,5,6]]
cumprod(a) # if axis is not provided, cumprod is done over the flattened input.
-> [ 1, 2, 6, 24, 120, 720]
cumprod(a, dtype="float32")
-> [ 1., 2., 6., 24., 120., 720.]
cumprod(a, axis=0) # multiply over rows for each of the 3 columns
-> [[1, 2, 3],
[4, 10, 18]]
cumprod(a, axis=1)
-> [[ 1, 2, 6],
[ 4, 20, 120]]
a = [1, 1, 1, 0, 1, 1, 0] # a is a boolean array
cumprod(a, dtype=int32) # dtype should be provided to get the expected results
-> [1, 1, 1, 0, 0, 0, 0]
"""
return _make.cumprod(data, axis, dtype, exclusive)


def unique(data, is_sorted=True, return_counts=False):
"""
Find the unique elements of a 1-D tensor. Please note `output` and `counts` are all padded to
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from .scatter_add import *
from .argwhere import *
from .interpolate import *
from .cumsum import *
from .scan import *
from .einsum import *
from .unique import *
from . import generic
Expand Down
Loading

0 comments on commit 178d959

Please sign in to comment.