Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY][OP] Support MXNet-style attributes for reshape_like #6851

Merged
merged 5 commits into from
Nov 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,26 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
}
}; // struct ReshapeAttrs

/*! \brief Attributes used in MXNet-style reshape_like operators */
struct ReshapeLikeAttrs : public tvm::AttrsNode<ReshapeLikeAttrs> {
int lhs_begin;
Integer lhs_end; // can be None
int rhs_begin;
Integer rhs_end; // can be None
TVM_DECLARE_ATTRS(ReshapeLikeAttrs, "relay.attrs.ReshapeLikeAttrs") {
TVM_ATTR_FIELD(lhs_begin).set_default(0).describe(
"The axis of the input where reshaping should begin.");
TVM_ATTR_FIELD(lhs_end)
.set_default(NullValue<Integer>())
.describe("The axis of the input where reshaping should end, exclusive.");
TVM_ATTR_FIELD(rhs_begin).set_default(0).describe(
"The axis of the shape_like tensor to begin taking dimensions from.");
TVM_ATTR_FIELD(rhs_end)
.set_default(NullValue<Integer>())
.describe("The axis of the shape_like tensor to end taking dimensions from, exclusive.");
}
}; // struct ReshapeLikeAttrs

struct ScatterAttrs : public tvm::AttrsNode<ScatterAttrs> {
Integer axis;

Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ class ReshapeAttrs(Attrs):
"""Attributes for transform.reshape"""


@tvm._ffi.register_object("relay.attrs.ReshapeLikeAttrs")
class ReshapeLikeAttrs(Attrs):
"""Attributes for transform.reshape_like"""


@tvm._ffi.register_object("relay.attrs.GatherAttrs")
class GatherAttrs(Attrs):
"""Attributes for transform.gather"""
Expand Down
43 changes: 35 additions & 8 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,28 +308,55 @@ def scatter_add(data, indices, updates, axis):
return _make.scatter_add(data, indices, updates, axis)


def reshape_like(data, shape_like):
"""Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
the input array into an output array with the same shape as the second input array.
def reshape_like(data, shape_like, lhs_begin=0, lhs_end=None, rhs_begin=0, rhs_end=None):
altanh marked this conversation as resolved.
Show resolved Hide resolved
"""Reshapes the input tensor by the size of another tensor.
altanh marked this conversation as resolved.
Show resolved Hide resolved
For an input tensor with shape ``(d0, d1, ..., d(k-1))``, `reshape_like` operation reshapes
the input tensor into an output tensor with the same shape as the second input tensor,
in particular reshaping the dimensions of `data` in `[lhs_begin, lhs_end)` using the dimensions
from `shape_like` in `[rhs_begin, rhs_end)`.

.. note::
Sizes for both array should be compatible.
Sizes for `data` and the output tensor should be compatible.

Parameters
----------
data : relay.Expr
The input data to the operator.

shape_like : tuple of int
The new shape. Should be compatible with the original shape.
shape_like : relay.Expr
The tensor to reshape data like. Should be compatible with the original shape on the
reshaped dimensions.

lhs_begin : int, optional
The axis of data to begin reshaping. Default is 0.

lhs_end : int or None, optional
The axis of data where reshaping should stop, exclusive. Default is None which reshapes to
the end.

rhs_begin : int, optional
The axis of shape_like where the target shape begins. Default is 0.

rhs_end : int or None, optional
The axis of shape_like where the target shape ends, exclusive. Default is None which extends
to the end.

Returns
-------
ret : relay.Expr
The computed result.

Examples
--------
.. code-block:: python

data.shape == (1, 2, 3, 4)
shape_like.shape == (6, 2, 2, 3)

ret = relay.reshape_like(data, shape_like, lhs_begin=1, rhs_end=3)
ret.shape == (1, 6, 2, 2)
"""
return _make.reshape_like(data, shape_like)
return _make.reshape_like(data, shape_like, lhs_begin, lhs_end, rhs_begin, rhs_end)


def take(data, indices, axis=None, mode="clip"):
Expand Down
3 changes: 3 additions & 0 deletions src/relay/op/make_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ Expr MakeRepeat(Expr data, int repeats, int axis);

Expr MakeReshape(Expr data, Array<Integer> newshape);

Expr MakeReshapeLike(Expr lhs, Expr rhs, int lhs_begin, Integer lhs_end, int rhs_begin,
Integer rhs_end);

Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis);

Expr MakeSqueeze(Expr data, Array<Integer> axis);
Expand Down
66 changes: 61 additions & 5 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ RELAY_REGISTER_OP("transpose")

/* relay.reshape */
TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
TVM_REGISTER_NODE_TYPE(ReshapeLikeAttrs);

Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const Attrs& attrs) {
const auto* param = attrs.as<ReshapeAttrs>();
Expand Down Expand Up @@ -641,11 +642,49 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return true;
}

Array<PrimExpr> infer_reshape_like(const Array<PrimExpr>& lhs_shape,
const Array<PrimExpr>& rhs_shape, const Attrs& attrs) {
const auto* like_attrs = attrs.as<ReshapeLikeAttrs>();
CHECK(!like_attrs->lhs_end.defined() || like_attrs->lhs_end.as<IntImmNode>())
altanh marked this conversation as resolved.
Show resolved Hide resolved
<< "lhs_end must be a concrete integer or None";
CHECK(!like_attrs->rhs_end.defined() || like_attrs->rhs_end.as<IntImmNode>())
<< "rhs_end must be a concrete integer or None";

int64_t lhs_shape_size = static_cast<int64_t>(lhs_shape.size());
int64_t rhs_shape_size = static_cast<int64_t>(rhs_shape.size());
int64_t lhs_begin = static_cast<int64_t>(like_attrs->lhs_begin);
int64_t lhs_end =
like_attrs->lhs_end.defined() ? like_attrs->lhs_end.as<IntImmNode>()->value : lhs_shape_size;
int64_t rhs_begin = static_cast<int64_t>(like_attrs->rhs_begin);
int64_t rhs_end =
like_attrs->rhs_end.defined() ? like_attrs->rhs_end.as<IntImmNode>()->value : rhs_shape_size;

// handle negative axes
lhs_begin = lhs_begin < 0 ? lhs_begin + lhs_shape_size : lhs_begin;
lhs_end = lhs_end < 0 ? lhs_end + lhs_shape_size : lhs_end;
rhs_begin = rhs_begin < 0 ? rhs_begin + rhs_shape_size : rhs_begin;
rhs_end = rhs_end < 0 ? rhs_end + rhs_shape_size : rhs_end;

Array<PrimExpr> shape_like;
for (auto i = 0; i < lhs_begin; i++) {
shape_like.push_back(lhs_shape[i]);
}
for (auto i = rhs_begin; i < rhs_end; i++) {
shape_like.push_back(rhs_shape[i]);
}
for (auto i = lhs_end; i < lhs_shape_size; i++) {
shape_like.push_back(lhs_shape[i]);
}
return shape_like;
}

Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
// Quick path for reshape_like
if (!attrs.as<ReshapeAttrs>()) {
return {topi::reshape(inputs[0], inputs[1]->shape)};
ICHECK(attrs.as<ReshapeLikeAttrs>() != nullptr);
auto shape_like = infer_reshape_like(inputs[0]->shape, inputs[1]->shape, attrs);
return {topi::reshape(inputs[0], shape_like)};
}

const auto* out_ttype = out_type.as<TensorTypeNode>();
Expand Down Expand Up @@ -746,6 +785,7 @@ Example::
*/
bool ReshapeLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
ICHECK(attrs.as<ReshapeLikeAttrs>() != nullptr);
ICHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
Expand All @@ -755,6 +795,7 @@ bool ReshapeLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
if (reshape_like == nullptr) {
return false;
}
auto shape_like = infer_reshape_like(data->shape, reshape_like->shape, attrs);
// Only check When input data has static shape.
bool is_static_shape = true;
for (size_t i = 0; i < data->shape.size(); ++i) {
Expand All @@ -763,17 +804,24 @@ bool ReshapeLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
break;
}
}
auto output_type = TensorType(shape_like, data->dtype);
if (is_static_shape) {
ICHECK(reporter->AssertEQ(data->Size(), reshape_like->Size()))
ICHECK(reporter->AssertEQ(data->Size(), output_type->Size()))
<< "Reshape inputs size should be compatible.";
}
reporter->Assign(types[2], TensorType(reshape_like->shape, data->dtype));
reporter->Assign(types[2], output_type);
return true;
}

Expr MakeReshapeLike(Expr data, Expr shape_like) {
Expr MakeReshapeLike(Expr lhs, Expr rhs, int lhs_begin, Integer lhs_end, int rhs_begin,
Integer rhs_end) {
auto attrs = make_object<ReshapeLikeAttrs>();
attrs->lhs_begin = std::move(lhs_begin);
attrs->lhs_end = std::move(lhs_end);
attrs->rhs_begin = std::move(rhs_begin);
attrs->rhs_end = std::move(rhs_end);
static const Op& op = Op::Get("reshape_like");
return Call(op, {data, shape_like}, Attrs(), {});
return Call(op, {lhs, rhs}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.reshape_like").set_body_typed(MakeReshapeLike);
Expand All @@ -784,7 +832,15 @@ For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation re
the input array into an output array with the same shape as the second input array.
.. note::
Sizes for both array should be compatible.
Example::

data.shape == (1, 2, 3, 4)
shape_like.shape == (6, 2, 2, 3)

ret = reshape_like(data, shape_like, lhs_begin=1, rhs_end=3)
ret.shape == (1, 6, 2, 2)
)code" TVM_ADD_FILELINE)
.set_attrs_type<ReshapeLikeAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("shape_like", "Tensor", "Shape tensor.")
Expand Down
6 changes: 3 additions & 3 deletions src/relay/transforms/pattern_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -594,9 +594,9 @@ inline Expr LeftShift(Expr x, Expr nbit) {
return Call(op, {x, nbit}, Attrs(), {});
}

inline Expr ReshapeLike(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("reshape_like");
return Call(op, {lhs, rhs}, Attrs(), {});
inline Expr ReshapeLike(Expr lhs, Expr rhs, int lhs_begin, Integer lhs_end, int rhs_begin,
Integer rhs_end) {
return MakeReshapeLike(lhs, rhs, lhs_begin, lhs_end, rhs_begin, rhs_end);
}

inline Expr Copy(Expr data) {
Expand Down
41 changes: 36 additions & 5 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,17 +316,45 @@ def test_reshape_like_infer_type():
zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType((1, 8, 8), "float32")

# partial reshaping
x = relay.var("x", relay.TensorType((1, 2, 3, 4), "float32"))
y = relay.var("y", relay.TensorType((1, 6, 5), "float32"))
z = relay.reshape_like(x, y, lhs_begin=1, lhs_end=3, rhs_begin=1, rhs_end=2)
zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType((1, 6, 4), "float32")

x = relay.var("x", relay.TensorType((1, 2, 3, 4), "float32"))
y = relay.var("y", relay.TensorType((2, 3, 4, 1, 6), "float32"))
z = relay.reshape_like(x, y, rhs_end=3)
zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType((2, 3, 4), "float32")
z = relay.reshape_like(x, y, rhs_begin=2)
zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType((4, 1, 6), "float32")

# symbolic partial reshaping
n, c, h, w = te.size_var("n"), 2, 3, te.size_var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = relay.var("y", relay.TensorType((5, 6), "float32"))
z = relay.var("z", relay.TensorType((4,), "float32"))
w = relay.reshape_like(x, y, lhs_end=3)
w = relay.reshape_like(w, z, lhs_begin=2)
altanh marked this conversation as resolved.
Show resolved Hide resolved
w = run_infer_type(w)
assert w.checked_type == relay.TensorType((5, 6, 4), "float32")


@tvm.testing.uses_gpu
def test_reshape_like():
def verify_reshape_like(shape, oshape):
def verify_reshape_like(shape, oshape, shape_like=None, reshape_like_kwargs={}):
if shape_like is None:
shape_like = oshape
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
y_data = np.random.uniform(low=-1, high=1, size=oshape).astype("float32")
ref_res = np.reshape(x_data, y_data.shape)
y_data = np.random.uniform(low=-1, high=1, size=shape_like).astype("float32")
ref_res = np.reshape(x_data, oshape)

x = relay.var("x", relay.TensorType(shape, "float32"))
y = relay.var("x", relay.TensorType(oshape, "float32"))
z = relay.reshape_like(x, y)
y = relay.var("x", relay.TensorType(shape_like, "float32"))
z = relay.reshape_like(x, y, **reshape_like_kwargs)
zz = run_infer_type(z)
assert zz.checked_type == relay.ty.TensorType(ref_res.shape, "float32")

Expand All @@ -340,6 +368,9 @@ def verify_reshape_like(shape, oshape):

verify_reshape_like((2, 3, 4), (1, 8, 3))
verify_reshape_like((4, 7), (2, 7, 2))
verify_reshape_like(
(1, 2, 3, 4), (1, 6, 4), (1, 6, 5), dict(lhs_begin=1, lhs_end=3, rhs_begin=1, rhs_end=2)
)


def test_take_infer_type():
Expand Down