From de2934017191880bd2e10d98d0f613d4fb52fc53 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 22 Feb 2021 13:04:35 +0900 Subject: [PATCH 1/9] enforce static dim for non-concat axis --- src/relay/op/tensor/transform.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index 95a83a905908..edcc3ae131f5 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -116,7 +116,6 @@ bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs } } int non_any_size = static_cast(non_any.size()); - if (non_any_size != data_length) oshape[i] = Any(); if (i != axis) { for (int k = 1; k < non_any_size; k++) { if (reporter->AssertEQ(non_any[0], non_any[k])) continue; @@ -124,6 +123,10 @@ bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs "relay.concatenate requires all tensors have the same shape " "on non-concatenating axes"); } + if (non_any_size > 0) { + // For non concat-axes, enforce static shape constraint + oshape[i] = non_any[0]; + } } } From b9288eec36e175f705e5b63507825786a1cfdef9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 22 Feb 2021 18:37:23 +0900 Subject: [PATCH 2/9] assign any when all dims are dyn --- src/relay/op/tensor/transform.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index edcc3ae131f5..c65d372da7c2 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -124,8 +124,10 @@ bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs "on non-concatenating axes"); } if (non_any_size > 0) { - // For non concat-axes, enforce static shape constraint + // For non-concat axes, enforce static shape constraint oshape[i] = non_any[0]; + } else { + oshape[i] = Any(); } } } From 26a20ee6fc58b8331452c1bc251177b387189c5b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 22 Feb 2021 19:17:43 +0900 Subject: [PATCH 3/9] add missing case --- src/relay/op/tensor/transform.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index c65d372da7c2..c37831762cd8 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -129,6 +129,9 @@ bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs } else { oshape[i] = Any(); } + } else if (non_any_size != data_length) { + // For concat axis, if there is one any among input dims, the output dim is dynamic. + oshape[i] = Any(); } } From b8722d9d825ba84bfedeb25b754881900921670d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 22 Feb 2021 19:23:32 +0900 Subject: [PATCH 4/9] simplify --- src/relay/op/tensor/transform.h | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index c37831762cd8..a89bbdf82b67 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -116,21 +116,16 @@ bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs } } int non_any_size = static_cast(non_any.size()); - if (i != axis) { + if (i != axis && non_any_size > 0) { for (int k = 1; k < non_any_size; k++) { if (reporter->AssertEQ(non_any[0], non_any[k])) continue; throw Error( "relay.concatenate requires all tensors have the same shape " "on non-concatenating axes"); } - if (non_any_size > 0) { - // For non-concat axes, enforce static shape constraint - oshape[i] = non_any[0]; - } else { - oshape[i] = Any(); - } + // For non-concat axes, enforce static shape constraint + oshape[i] = non_any[0]; } else if (non_any_size != data_length) { - // For concat axis, if there is one any among input dims, the output dim is dynamic. oshape[i] = Any(); } } From cf5dd2700cf4468d9c41954372041a876a04c07c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 22 Feb 2021 19:38:52 +0900 Subject: [PATCH 5/9] add test --- tests/python/relay/test_any.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 9d05631a753a..fa581778b7db 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -208,6 +208,16 @@ def test_any_concat(): ref = np.concatenate(x_np, axis=0) check_result(x_np, mod, ref) + x = [relay.var("x", shape=(relay.Any(), 3), dtype="float32") for _ in range(3)] + x.append(relay.var("x", shape=(relay.Any(), relay.Any()), dtype="float32")) + z = relay.op.concatenate(x, axis=0) + mod = tvm.IRModule() + mod["main"] = relay.Function(x, z) + typed_mod = relay.transform.InferType()(mod) + assert typed_mod["main"].body.checked_type == relay.TensorType( + (relay.Any(), 3), dtype="float32" + ) + def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newshape=False): x = relay.var("x", shape=x_shape, dtype="float32") From 4651916a832dc32ce6af55eda55d513fd9783126 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 25 Feb 2021 20:14:03 +0900 Subject: [PATCH 6/9] only enforce static dim constraint if concat output is dynamic --- src/relay/op/tensor/transform.h | 54 ++++++++++++++++++++++++--------- 1 file changed, 39 insertions(+), 15 deletions(-) diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index a89bbdf82b67..0e1d49d930e3 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -103,29 +103,53 @@ bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs // Calculate shape std::vector oshape(first->shape.begin(), first->shape.end()); int data_length = static_cast(tensor_tuple->fields.size()); + // Decide if this is dynamic concat + bool is_dynamic_concat = false; + std::vector input_tensors; + for (int i = 0; i < data_length; ++i) { + const auto& e = Downcast(tensor_tuple->fields[i]); + input_tensors.push_back(e); + if (e->shape[axis].as()) { + is_dynamic_concat = true; + oshape[axis] = Any(); + break; + } else { + // accumulate axis dimension + if (i > 0 && !oshape[axis].as()) { + oshape[axis] += e->shape[i]; + } + } + } + for (int i = 0; i < ndim; ++i) { + if (i == axis) continue; std::vector non_any; for (int j = 0; j < data_length; ++j) { - const auto& e = Downcast(tensor_tuple->fields[j]); + const auto& e = input_tensors[j]; if (!e->shape[i].as()) { non_any.push_back(e->shape[i]); - // accumulate axis dimension - if (j > 0 && i == axis && !oshape[i].as()) { - oshape[i] += e->shape[i]; - } } } - int non_any_size = static_cast(non_any.size()); - if (i != axis && non_any_size > 0) { - for (int k = 1; k < non_any_size; k++) { - if (reporter->AssertEQ(non_any[0], non_any[k])) continue; - throw Error( - "relay.concatenate requires all tensors have the same shape " - "on non-concatenating axes"); - } - // For non-concat axes, enforce static shape constraint + size_t non_any_size = non_any.size(); + for (size_t k = 1; k < non_any_size; k++) { + if (reporter->AssertEQ(non_any[0], non_any[k])) continue; + throw Error( + "relay.concatenate requires all tensors have the same shape " + "on non-concatenating axes"); + } + + if (non_any_size > 0 && is_dynamic_concat) { + // For non-concat axes, we want to enforce static shape constraint. + // However, if the concat axis is static, the output shape would become static while + // the input could be partially static/dynamic. To prevent runtime segfaults due to the lack + // of runtime input shape checking for such cases, static shape constraint is only enforced + // when the output shape is dynamic. + // + // Examples (both concat on the first axis): + // * [(?, 3), (?, ?)] -> (?, 3) + // * [(1, 3), (1, ?)] -> (2, ?) oshape[i] = non_any[0]; - } else if (non_any_size != data_length) { + } else { oshape[i] = Any(); } } From f7844a2d7dc920f3d642d33d2cc2d980b9cc1f67 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 25 Feb 2021 20:36:59 +0900 Subject: [PATCH 7/9] more update to concat type rel --- src/relay/op/tensor/transform.h | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index 0e1d49d930e3..75ea8511c144 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -101,28 +101,33 @@ bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs } // Calculate shape - std::vector oshape(first->shape.begin(), first->shape.end()); - int data_length = static_cast(tensor_tuple->fields.size()); - // Decide if this is dynamic concat + std::vector oshape(ndim); + const size_t data_length = tensor_tuple->fields.size(); + + // Accumulate the concat axis output dim or decide if this is dynamic concat bool is_dynamic_concat = false; std::vector input_tensors; - for (int i = 0; i < data_length; ++i) { + IndexExpr concat_output_dim = first->shape[axis]; + for (size_t i = 0; i < data_length; ++i) { const auto& e = Downcast(tensor_tuple->fields[i]); input_tensors.push_back(e); if (e->shape[axis].as()) { is_dynamic_concat = true; - oshape[axis] = Any(); - break; - } else { + concat_output_dim = Any(); + } else if (i > 0 && !is_dynamic_concat) { // accumulate axis dimension - if (i > 0 && !oshape[axis].as()) { - oshape[axis] += e->shape[i]; - } + concat_output_dim += e->shape[axis]; } } + oshape[axis] = concat_output_dim; + for (int i = 0; i < ndim; ++i) { - if (i == axis) continue; + if (i == axis) { + // The concat axis is already handled above. + // The rest of the body sets the output shape for non-concat axes + continue; + } std::vector non_any; for (int j = 0; j < data_length; ++j) { const auto& e = input_tensors[j]; @@ -138,7 +143,10 @@ bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs "on non-concatenating axes"); } - if (non_any_size > 0 && is_dynamic_concat) { + if (non_any_size == data_length) { + // All static case + oshape[i] = non_any[0]; + } else if (non_any_size > 0 && is_dynamic_concat) { // For non-concat axes, we want to enforce static shape constraint. // However, if the concat axis is static, the output shape would become static while // the input could be partially static/dynamic. To prevent runtime segfaults due to the lack From 83f698ea010f72bfea4da964204cf23442bfd7bd Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 25 Feb 2021 21:03:51 +0900 Subject: [PATCH 8/9] update tests --- src/relay/op/tensor/transform.h | 2 +- tests/python/relay/test_any.py | 25 ++++++++++++++++++------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index 75ea8511c144..273700e71e66 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -151,7 +151,7 @@ bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs // However, if the concat axis is static, the output shape would become static while // the input could be partially static/dynamic. To prevent runtime segfaults due to the lack // of runtime input shape checking for such cases, static shape constraint is only enforced - // when the output shape is dynamic. + // when the output concat axis is dynamic. // // Examples (both concat on the first axis): // * [(?, 3), (?, ?)] -> (?, 3) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index fa581778b7db..b75cc5f5e750 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -208,15 +208,26 @@ def test_any_concat(): ref = np.concatenate(x_np, axis=0) check_result(x_np, mod, ref) + def test_oshape(in_vars, axis, oshape): + z = relay.op.concatenate(in_vars, axis=axis) + mod = tvm.IRModule() + mod["main"] = relay.Function(in_vars, z) + typed_mod = relay.transform.InferType()(mod) + assert typed_mod["main"].body.checked_type == relay.TensorType(oshape, dtype="float32") + x = [relay.var("x", shape=(relay.Any(), 3), dtype="float32") for _ in range(3)] x.append(relay.var("x", shape=(relay.Any(), relay.Any()), dtype="float32")) - z = relay.op.concatenate(x, axis=0) - mod = tvm.IRModule() - mod["main"] = relay.Function(x, z) - typed_mod = relay.transform.InferType()(mod) - assert typed_mod["main"].body.checked_type == relay.TensorType( - (relay.Any(), 3), dtype="float32" - ) + + test_oshape(x, 0, (relay.Any(), 3)) + test_oshape(x, 1, (relay.Any(), relay.Any())) + + # [(1, 3), (1, ?)] -> (2, ?) + x = [ + relay.var("x", shape=(1, 3), dtype="float32"), + relay.var("x", shape=(1, relay.Any()), dtype="float32"), + ] + test_oshape(x, 0, (2, relay.Any())) + test_oshape(x, 1, (1, relay.Any())) def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newshape=False): From 5d667031e59ccc2811bd9b155611a7398adc3577 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 26 Feb 2021 04:09:49 +0900 Subject: [PATCH 9/9] fixed compile warning --- src/relay/op/tensor/transform.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index 273700e71e66..dbf8537e0dad 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -129,7 +129,7 @@ bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs continue; } std::vector non_any; - for (int j = 0; j < data_length; ++j) { + for (size_t j = 0; j < data_length; ++j) { const auto& e = input_tensors[j]; if (!e->shape[i].as()) { non_any.push_back(e->shape[i]);