Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay, Topi] [Frontend][TFLite, MXNet] ReverseSequence operator #5495

Merged
merged 18 commits into from
Jun 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ List of operators
topi.reinterpret
topi.transpose
topi.flip
topi.reverse_sequence
topi.strided_slice
topi.expand_dims
topi.reshape
Expand Down Expand Up @@ -152,6 +153,7 @@ topi
.. autofunction:: topi.reinterpret
.. autofunction:: topi.transpose
.. autofunction:: topi.flip
.. autofunction:: topi.reverse_sequence
.. autofunction:: topi.strided_slice
.. autofunction:: topi.expand_dims
.. autofunction:: topi.reshape
Expand Down
1 change: 1 addition & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ This level enables additional math and transform operators.
tvm.relay.repeat
tvm.relay.tile
tvm.relay.reverse
tvm.relay.reverse_sequence
tvm.relay.unravel_index
tvm.relay.sparse_to_dense

Expand Down
14 changes: 14 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,20 @@ struct ReverseAttrs : public tvm::AttrsNode<ReverseAttrs> {
}
}; // struct ReverseAttrs

/*! \brief Attributes used in reverse_sequence operators */
struct ReverseSequenceAttrs : public tvm::AttrsNode<ReverseSequenceAttrs> {
Integer seq_axis;
Integer batch_axis;

TVM_DECLARE_ATTRS(ReverseSequenceAttrs, "relay.attrs.ReverseSequenceAttrs") {
TVM_ATTR_FIELD(seq_axis).set_default(1).describe(
"The seq axis along which to reverse elements.");
TVM_ATTR_FIELD(batch_axis)
.set_default(0)
.describe("The batch axis along which to slice the tensor.");
}
}; // struct ReverseSequenceAttrs

/*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
// use axis to make the name numpy compatible.
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,21 @@ def _mx_reverse(inputs, attrs):
return _op.reverse(inputs[0], **new_attrs)


def _mx_sequence_reverse(inputs, attrs):
new_attrs = {}
use_seq_lengths = attrs.get_bool("use_sequence_length")
if not use_seq_lengths:
assert len(inputs) == 1
new_attrs["axis"] = attrs.get_int("axis")
return _op.reverse(inputs[0], **new_attrs)

assert len(inputs) == 2
new_attrs["seq_axis"] = attrs.get_int("axis")
# MXNet assumes batch_axis as 1.
new_attrs["batch_axis"] = 1
return _op.reverse_sequence(inputs[0], inputs[1], **new_attrs)


def _mx_roi_align(inputs, attrs):
new_attrs = {}
new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size")
Expand Down Expand Up @@ -2001,6 +2016,7 @@ def impl(inputs, input_types):
"take" : _mx_take,
"gather_nd" : _mx_gather_nd,
"reverse" : _mx_reverse,
"SequenceReverse" : _mx_sequence_reverse,
"squeeze" : _mx_squeeze,
"broadcast_axis": _mx_broadcast_axis,
"broadcast_axes": _mx_broadcast_axis,
Expand Down
34 changes: 29 additions & 5 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def __init__(self, model, subgraph, exp_tab):
'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor,
'ROUND': self.convert_round,
'RSQRT': self.convert_rsqrt,
'REVERSE_SEQUENCE': self.convert_reverse_sequence,
'SELECT': self.convert_select,
'SHAPE': self.convert_shape,
'SIN': self.convert_sin,
Expand Down Expand Up @@ -1868,6 +1869,33 @@ def convert_transpose(self, op):

return out

def convert_reverse_sequence(self, op):
"""Convert TFLite REVERSE_SEQUENCE"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.ReverseSequenceOptions import ReverseSequenceOptions
except ImportError:
raise ImportError("The tflite package must be installed")

if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFLite does not support quantized REVERSE_SEQUENCE operator yet.')

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"

in_expr = self.get_tensor_expr(input_tensors[0])
length_expr = self.get_tensor_expr(input_tensors[1])

assert op.BuiltinOptionsType() == BuiltinOptions.ReverseSequenceOptions
op_options = op.BuiltinOptions()
options = ReverseSequenceOptions()
options.Init(op_options.Bytes, op_options.Pos)
batch_axis = options.BatchDim()
seq_axis = options.SeqDim()

return _op.reverse_sequence(in_expr, length_expr, seq_axis, batch_axis)

def convert_cast(self, op):
"""Convert TFLite CAST"""
try:
Expand Down Expand Up @@ -2566,14 +2594,10 @@ def has_expr(self, input_tensor_idx):
return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx))

def get_tensor_expr(self, tensor):
""" Returns constant expr for constant else a tensor expr"""
""" Return the Relay expr for tensor. """
if self.has_expr(tensor.tensor_idx):
# In most cases, we can assume that TOCO fuses elemwise operators
# with constants - it means both will be tensors.
expr = self.get_expr(tensor.tensor_idx)
else:
# However, in some corner cases, the elemwise operator is not fused,
# we can receive as constant.
type_str = self.get_tensor_type_str(tensor.tensor.Type())
expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str)

Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
_reg.register_injective_schedule("full_like")
_reg.register_injective_schedule("arange")
_reg.register_injective_schedule("reverse")
_reg.register_injective_schedule("reverse_sequence")
_reg.register_injective_schedule("cast")
_reg.register_injective_schedule("cast_like")
_reg.register_injective_schedule("reinterpret")
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ class TileAttrs(Attrs):
class ReverseAttrs(Attrs):
"""Attributes used in reverse operators"""

@tvm._ffi.register_object("relay.attrs.ReverseSequenceAttrs")
class ReverseSequenceAttrs(Attrs):
"""Attributes used in reverse sequence operators"""


@tvm._ffi.register_object("relay.attrs.SqueezeAttrs")
class SqueezeAttrs(Attrs):
Expand Down
47 changes: 47 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,53 @@ def reverse(data, axis):
return _make.reverse(data, axis)


def reverse_sequence(data, seq_lengths, seq_axis=1, batch_axis=0):
"""Reverse the tensor for variable length slices.
Input is first sliced along batch axis and then elements are reversed along seq axis.

Parameters
----------
data : relay.Expr
The tensor to be reversed.

seq_lengths : relay.Expr
A 1D Tensor with length a.dims[batch_axis]
Must be one of the following types: int32, int64
if seq_lengths[i] > a.dims[seq_axis], it is rounded to a.dims[seq_axis]
if seq_lengths[i] < 1, it is rounded to 1

seq_axis : int, optional
The axis along which the elements will be reversed. Default is 1.

batch_axis : int, optional
The axis along which the tensor will be sliced. Default is 0.

Returns
-------
ret : relay.Expr
The computed result of same shape and type as of input.

Examples
--------
.. code-block:: python

x = [[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
[12, 13, 14, 15]]
relay.reverse(x, [1, 2, 3, 4], 0, 1) = [[0, 5, 10, 15],
[4, 1, 6, 11],
[8, 9, 2, 7],
[12, 13, 14, 3]]

relay.reverse(x, [1, 2, 3, 4], 1, 0) = [[0, 1, 2, 3],
[5, 4, 6, 7],
[10, 9, 8, 11],
[15, 14, 13, 12]]
"""
return _make.reverse_sequence(data, seq_lengths, seq_axis, batch_axis)


def where(condition, x, y):
"""Selecting elements from either x or y depending on the value of the
condition.
Expand Down
93 changes: 92 additions & 1 deletion src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1397,7 +1397,8 @@ Array<te::Tensor> ReverseCompute(const Attrs& attrs, const Array<te::Tensor>& in
const Type& out_type) {
const ReverseAttrs* param = attrs.as<ReverseAttrs>();
CHECK(param != nullptr);
return {topi::flip(inputs[0], param->axis)};
// pass empty seq_length tensor to reverse_sequence
return {topi::reverse_sequence(inputs[0], te::Tensor(), param->axis)};
}

Expr MakeReverse(Expr data, int axis) {
Expand All @@ -1423,6 +1424,96 @@ RELAY_REGISTER_OP("reverse")
.set_attr<FTVMCompute>("FTVMCompute", ReverseCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// reverse sequence operator
TVM_REGISTER_NODE_TYPE(ReverseSequenceAttrs);

bool ReverseSequenceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, seq_lengths, result]
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();

if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "reverse_sequence: expect input type to be TensorType but get " << types[0];
return false;
}

const auto* seq_lengths = types[1].as<TensorTypeNode>();
if (seq_lengths == nullptr) {
CHECK(types[1].as<IncompleteTypeNode>())
<< "reverse_sequence: expect input type to be TensorType but get " << types[1];
return false;
}

const int seq_lengths_dim = static_cast<int>(seq_lengths->shape.size());
CHECK(seq_lengths_dim == 1) << "For reverse_sequnece, seq_lengths must be a 1D vector";
CHECK(seq_lengths->dtype.is_int())
<< "For reverse_sequnece, seq_lengths must be tensor of integer";

const auto* param = attrs.as<ReverseSequenceAttrs>();
const int ndim = static_cast<int>(data->shape.size());
int batch_axis = param->batch_axis;
CHECK(-ndim <= batch_axis && batch_axis < ndim)
<< "reverse_sequence only accepts `batch_axis` in [-data.ndim, data.ndim - 1]"
<< ", but got batch_axis = " << batch_axis << ", and data.ndim = " << ndim;

if (batch_axis < 0) {
batch_axis = static_cast<int>(data->shape.size()) + batch_axis;
}
CHECK(reporter->Assert(seq_lengths->shape[0] == data->shape[batch_axis]))
<< "For reverse_sequnece seq_lengths size should match with dimension of batch axis"
<< ", but got dimension of batch_axis = " << data->shape[batch_axis]
<< ", and seq_length size = " << seq_lengths->shape[0];

const int seq_axis = param->seq_axis;
CHECK(-ndim <= seq_axis && seq_axis < ndim)
<< "reverse_sequnece only accepts `seq_axis` in [-data.ndim, data.ndim - 1]"
<< ", but got seq_axis = " << seq_axis << ", and data.ndim = " << ndim;

reporter->Assign(types[2], types[0]);
return true;
}

Array<te::Tensor> ReverseSequenceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const ReverseSequenceAttrs* param = attrs.as<ReverseSequenceAttrs>();
CHECK(param != nullptr);
return {topi::reverse_sequence(inputs[0], inputs[1], param->seq_axis, param->batch_axis)};
}

Expr MakeReverseSequence(Expr data, Expr seq_lengths, int seq_axis, int batch_axis) {
auto attrs = make_object<ReverseSequenceAttrs>();
attrs->seq_axis = seq_axis;
attrs->batch_axis = batch_axis;
static const Op& op = Op::Get("reverse_sequence");
return Call(op, {data, seq_lengths}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.reverse_sequence").set_body_typed(MakeReverseSequence);

RELAY_REGISTER_OP("reverse_sequence")
.describe(R"code(Reverses the tensor for variable length slices.
Input is first sliced along batch axis and then elements are reversed along seq axis.

- **data**: The input data to the operator.

- **seq_lengths**: A 1D Tensor with length data.dims[batch_axis].

- **seq_axis**: The axis along which the elements will be reversed. Default is 1.

- **batch_axis**: The axis along which the tensor will be sliced. Default is 0.

)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.set_attrs_type<ReverseSequenceAttrs>()
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("seq_lengths", "Tensor", "A 1D Tensor with length data.dims[batch_axis]")
.set_support_level(3)
.add_type_rel("ReverseSequence", ReverseSequenceRel)
.set_attr<FTVMCompute>("FTVMCompute", ReverseSequenceCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// where operator
bool WhereRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
Expand Down
34 changes: 34 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,39 @@ def verify(x_shape, y_shape, axes):
verify((3, 4), (2, 3), (0))
verify((3, 4), (2, 3), (-1))

def test_forward_sequence_reverse():
def verify(shape, seq_lengths, use_seq_lengths, seq_axis):
data_np = np.random.uniform(size=shape).astype("float32")

ref_res_args = [mx.nd.array(data_np), None, use_seq_lengths, seq_axis]
mx_sym_args = [mx.sym.var("data"), None, use_seq_lengths, seq_axis]
from_mxnet_args = [{"data": shape}, {"data": "float32"}]
in_data= [data_np]

if use_seq_lengths and seq_lengths:
seq_lengths_np = np.array(seq_lengths).astype("int32")
ref_res_args[1] = mx.nd.array(seq_lengths_np)
mx_sym_args[1] = mx.sym.var("seq_lengths")
from_mxnet_args[0].update({"seq_lengths": seq_lengths_np.shape})
from_mxnet_args[1].update({"seq_lengths": "int32"})
in_data.append(seq_lengths_np)

ref_res = mx.nd.SequenceReverse(*ref_res_args)
mx_sym = mx.sym.SequenceReverse(*mx_sym_args)
mod, _ = relay.frontend.from_mxnet(mx_sym, *from_mxnet_args)

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(*in_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

verify((3, 4), [1, 2, 3, 1], True, 0)
verify((3, 4), None, False, 0)
verify((3, 5, 5, 6), [1, 2, 3, 1, 3], True, 0)
# MXNet accepts axis value as 0 only
# verify((3, 4, 5, 6), None, False, 2)

def test_forward_l2_normalize():
data = mx.sym.var('data')
mx_sym = mx.sym.L2Normalization(data, mode="channel")
Expand Down Expand Up @@ -1228,6 +1261,7 @@ def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size
test_forward_scalar_ops()
test_forward_slice_like()
test_forward_slice_axis()
test_forward_sequence_reverse()
test_forward_l2_normalize()
test_forward_shape_array()
test_forward_squeeze()
Expand Down
Loading