Skip to content

Commit

Permalink
[FIX] Verify that tensor reshape is valid. (#6215)
Browse files Browse the repository at this point in the history
  • Loading branch information
tkonolige authored Aug 9, 2020
1 parent 9ad33fe commit ae0a062
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 2 deletions.
31 changes: 31 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include <tvm/ir/error.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/tir/data_layout.h>
Expand Down Expand Up @@ -581,6 +582,36 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
oshape.Set(infer_idx, infer_dim);
}

// Verify that the sum of dimensions in the output shape is the sum of
// dimensions in the input shape
bool found_dynamic = false;
int64_t oshape_sum = 1;
for (auto& x : oshape) {
// Check if we have a dynamic shape. If we do, we can't verify if the
// reshape is valid. Dynamic shapes are marker by using Any, but can also
// occur from SizeVar's. In the case of SizeVar, the shape expression can
// be an AST. We can't easily check if we have an AST because of a ShapeVar
// or some other reason, so our check for dynamic shape is just if we can
// convert the shape to in integer or not.
if (!x->IsInstance<tvm::Integer::ContainerType>()) {
found_dynamic = true;
break;
}
oshape_sum *= Downcast<tvm::Integer>(x)->value;
}
int64_t data_shape_sum = 1;
for (auto& x : data_shape) {
if (!x->IsInstance<tvm::Integer::ContainerType>()) {
found_dynamic = true;
break;
}
data_shape_sum *= Downcast<tvm::Integer>(x)->value;
}
if (!found_dynamic) {
CHECK_EQ(oshape_sum, data_shape_sum)
<< "Input tensor shape and reshaped shape are not compatible";
}

if (param->reverse) {
reporter->Assign(types[1],
TensorType(Array<IndexExpr>(oshape.rbegin(), oshape.rend()), data->dtype));
Expand Down
9 changes: 9 additions & 0 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import tvm
from tvm import te
from tvm import relay
from tvm.error import TVMError
from tvm.relay import create_executor, transform
from tvm.relay.testing import ctx_list, check_grad, run_infer_type

Expand Down Expand Up @@ -282,6 +283,13 @@ def verify_reshape(shape, newshape, oshape):
verify_reshape((2, 3, 4), (2, -4, -1, 3, -2), (2, 1, 3, 4))


def test_reshape_fail():
with pytest.raises(TVMError) as reshape_err:
x = relay.var("x", relay.TensorType([2,3], "float32"))
z = relay.reshape(x, [7])
zz = run_infer_type(z)


def test_reshape_like_infer_type():
# concrete shape
x = relay.var("x", relay.TensorType((1, 2, 3), "float32"))
Expand Down Expand Up @@ -1070,6 +1078,7 @@ def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_
test_transpose()
test_reshape_infer_type()
test_reshape()
test_reshape_fail()
test_reshape_like_infer_type()
test_reshape_like()
test_take_infer_type()
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/test_pass_combine_parallel_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def check(i, j, k, scale1, scale2, newshape):
tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)

check(3, 5, 4, 0.5, 0.25, (1, 1, 15))
check(100, 200, 300, 0.5, 0.25, (1, 1, 200))
check(100, 200, 300, 0.5, 0.25, (1, 1, 20000))


def test_combine_parallel_dense_flat():
Expand Down Expand Up @@ -369,7 +369,7 @@ def check(i, j, k, scale1, scale2, newshape1, newshape2):
tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)

check(3, 5, 4, 0.5, 0.25, (1, 1, 15), (1, 1, 30))
check(100, 200, 300, 0.5, 0.25, (1, 1, 200), (1, 1, 400))
check(100, 200, 300, 0.5, 0.25, (1, 1, 20000), (1, 1, 40000))


if __name__ == "__main__":
Expand Down

0 comments on commit ae0a062

Please sign in to comment.