From c6b766a4cea4e59384c2606deecdc5321ac3d41c Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 30 Dec 2020 17:29:06 -0800 Subject: [PATCH] [Relay][Op] Remove reverse attribute from reshape and reverse_reshape operators. (#7086) --- include/tvm/relay/attrs/transform.h | 4 - src/relay/op/dyn/tensor/transform.cc | 1 - src/relay/op/tensor/transform.cc | 76 ++++++++++++++----- src/relay/op/tensor/transform.h | 2 +- .../test_arm_compute_lib/test_reshape.py | 1 - 5 files changed, 59 insertions(+), 25 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index cbe989f93558..efa44e026c51 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -83,13 +83,9 @@ struct TransposeAttrs : public tvm::AttrsNode { /*! \brief Attributes used in reshape operators */ struct ReshapeAttrs : public tvm::AttrsNode { Array newshape; - bool reverse; TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") { TVM_ATTR_FIELD(newshape).describe( "The new shape. 
Should be compatible with the original shape."); - TVM_ATTR_FIELD(reverse) - .describe("Infer the special values from right to left if true") - .set_default(false); } }; // struct ReshapeAttrs diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc index 815f24b6bda9..e4e81e3612fb 100644 --- a/src/relay/op/dyn/tensor/transform.cc +++ b/src/relay/op/dyn/tensor/transform.cc @@ -90,7 +90,6 @@ Array ReshapeCompute(const Attrs& attrs, const Array& in Expr MakeReshape(Expr data, Expr newshape) { auto attrs = make_object(); - attrs->reverse = false; static const Op& op = Op::Get("dyn.reshape"); return Call(op, {data, newshape}, Attrs(attrs), {}); } diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 6819ea93f249..19ca6129ecbe 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -455,13 +455,14 @@ RELAY_REGISTER_OP("transpose") TVM_REGISTER_NODE_TYPE(ReshapeAttrs); TVM_REGISTER_NODE_TYPE(ReshapeLikeAttrs); -Array infer_newshape(const Array& data_shape, const Attrs& attrs) { +Array InferNewShape(const Array& data_shape, const Attrs& attrs, + bool reverse) { const auto* param = attrs.as(); Array oshape; Array ishape; Array newshape; - if (param->reverse) { + if (reverse) { ishape.Assign(data_shape.rbegin(), data_shape.rend()); newshape.Assign(param->newshape.rbegin(), param->newshape.rend()); } else { @@ -584,7 +585,6 @@ Array infer_newshape(const Array& data_shape, const Attrs& bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - const auto* param = attrs.as(); // types: [data, result] ICHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -594,16 +594,12 @@ bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, return false; } - const auto& oshape = infer_newshape(data->shape, attrs); + const auto& oshape = InferNewShape(data->shape, attrs, false); // Verify that the sum of dimensions in the 
output shape is the sum of // dimensions in the input shape Array data_shape; - if (param->reverse) { - data_shape.Assign(data->shape.rbegin(), data->shape.rend()); - } else { - data_shape = data->shape; - } + data_shape = data->shape; bool found_dynamic = false; int64_t oshape_sum = 1; @@ -633,12 +629,58 @@ bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, << "Input tensor shape and reshaped shape are not compatible"; } - if (param->reverse) { - reporter->Assign(types[1], - TensorType(Array(oshape.rbegin(), oshape.rend()), data->dtype)); - } else { - reporter->Assign(types[1], TensorType(oshape, data->dtype)); + reporter->Assign(types[1], TensorType(oshape, data->dtype)); + return true; +} + +bool ReverseReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // types: [data, result] + ICHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) { + ICHECK(types[0].as()) + << "reshape: expect input type to be TensorType but get " << types[0]; + return false; + } + + const auto& oshape = InferNewShape(data->shape, attrs, true); + + // Verify that the sum of dimensions in the output shape is the sum of + // dimensions in the input shape + Array data_shape; + data_shape.Assign(data->shape.rbegin(), data->shape.rend()); + + bool found_dynamic = false; + int64_t oshape_sum = 1; + for (auto& x : oshape) { + // Check if we have a dynamic shape. If we do, we can't verify if the + // reshape is valid. Dynamic shapes are marker by using Any, but can also + // occur from SizeVar's. In the case of SizeVar, the shape expression can + // be an AST. We can't easily check if we have an AST because of a ShapeVar + // or some other reason, so our check for dynamic shape is just if we can + // convert the shape to in integer or not. 
+ if (!x->IsInstance()) { + found_dynamic = true; + break; + } + oshape_sum *= Downcast(x)->value; } + int64_t data_shape_sum = 1; + for (auto& x : data_shape) { + if (!x->IsInstance()) { + found_dynamic = true; + break; + } + data_shape_sum *= Downcast(x)->value; + } + if (!found_dynamic) { + ICHECK_EQ(oshape_sum, data_shape_sum) + << "Input tensor shape and reshaped shape are not compatible"; + } + + reporter->Assign(types[1], + TensorType(Array(oshape.rbegin(), oshape.rend()), data->dtype)); return true; } @@ -701,7 +743,7 @@ Array ReshapeCompute(const Attrs& attrs, const Array& in } if (newshape_has_any) { - newshape = infer_newshape(inputs[0]->shape, attrs); + newshape = InferNewShape(inputs[0]->shape, attrs, false); } return {topi::reshape(inputs[0], newshape)}; } @@ -709,7 +751,6 @@ Array ReshapeCompute(const Attrs& attrs, const Array& in Expr MakeReshape(Expr data, Array newshape) { auto attrs = make_object(); attrs->newshape = std::move(newshape); - attrs->reverse = false; static const Op& op = Op::Get("reshape"); return Call(op, {data}, Attrs(attrs), {}); } @@ -2871,7 +2912,6 @@ RELAY_REGISTER_OP("auto_scheduler_layout_transform") Expr MakeReverseReshape(Expr data, Array newshape) { auto attrs = make_object(); attrs->newshape = std::move(newshape); - attrs->reverse = true; static const Op& op = Op::Get("contrib_reverse_reshape"); return Call(op, {data}, Attrs(attrs), {}); } @@ -2896,7 +2936,7 @@ example below:: .set_attrs_type() .add_argument("data", "Tensor", "The input tensor.") .set_support_level(10) - .add_type_rel("Reshape", ReshapeRel) + .add_type_rel("ReverseReshape", ReverseReshapeRel) .set_attr("FTVMCompute", ReshapeCompute) .set_attr("TOpPattern", kInjective); diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index 34aaf4689a59..a3770ff9cd8d 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -195,7 +195,7 @@ static inline Array> ConcatenateLayout(const Attrs& attrs, * \param attrs 
The attributes. * \return Output shape. */ -Array infer_newshape(const Array& data_shape, const Attrs& attrs); +Array InferNewShape(const Array& data_shape, const Attrs& attrs, bool reverse); } // namespace relay } // namespace tvm diff --git a/tests/python/contrib/test_arm_compute_lib/test_reshape.py b/tests/python/contrib/test_arm_compute_lib/test_reshape.py index 9364c6b1a478..94942727416a 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_reshape.py +++ b/tests/python/contrib/test_arm_compute_lib/test_reshape.py @@ -50,7 +50,6 @@ def _get_expected_codegen(input_shape, output_shape, dtype): "newshape": [[str(s) for s in output_shape]], "shape": [[list(output_shape)]], "dtype": [[dtype]], - "reverse": [["0"]], }, }