diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 9427dedfe3fa2..310ba20685fca 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -581,6 +582,29 @@ bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, oshape.Set(infer_idx, infer_dim); } + // Verify that the sum of dimensions in the output shape is the sum of + // dimensions in the input shape + bool found_dynamic = false; + int64_t oshape_sum = 1; + for(auto& x: oshape) { + if (x.as() != nullptr) { + found_dynamic = true; + break; + } + oshape_sum *= Downcast(x)->value; + } + int64_t data_shape_sum = 1; + for(auto& x: data_shape) { + if (x.as() != nullptr) { + found_dynamic = true; + break; + } + data_shape_sum *= Downcast(x)->value; + } + if (!found_dynamic) { + CHECK_EQ(oshape_sum, data_shape_sum) << "Input tensor shape and reshaped shape are not compatible"; + } + if (param->reverse) { reporter->Assign(types[1], TensorType(Array(oshape.rbegin(), oshape.rend()), data->dtype)); diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 76f10d6c1a182..db45fcbef6cbf 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -21,6 +21,7 @@ import tvm from tvm import te from tvm import relay +from tvm.error import TVMError from tvm.relay import create_executor, transform from tvm.relay.testing import ctx_list, check_grad, run_infer_type @@ -282,6 +283,13 @@ def verify_reshape(shape, newshape, oshape): verify_reshape((2, 3, 4), (2, -4, -1, 3, -2), (2, 1, 3, 4)) +def test_reshape_fail(): + with pytest.raises(TVMError) as reshape_err: + x = relay.var("x", relay.TensorType([2,3], "float32")) + z = relay.reshape(x, [7]) + zz = run_infer_type(z) + + def test_reshape_like_infer_type(): # concrete shape x = relay.var("x", relay.TensorType((1, 2, 3), "float32")) @@ -1070,6 +1078,7 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_ test_transpose() test_reshape_infer_type() test_reshape() + test_reshape_fail() test_reshape_like_infer_type() test_reshape_like() test_take_infer_type()