From 5782b7070288eb0de122f5dab91b38c26166a7d7 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 27 May 2021 08:31:11 +0900 Subject: [PATCH] Support layout transform for strided_slice with explicit axes (part 1) --- src/relay/op/tensor/transform.cc | 68 ++++++++++++------- .../relay/test_pass_convert_op_layout.py | 44 ++++++++++++ 2 files changed, 88 insertions(+), 24 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 7265520bade1..d156b2a0b7fb 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2537,34 +2537,54 @@ Array<Array<Layout>> StridedSliceInferCorrectLayout(const Attrs& attrs, // Not support NHW4c -> NCHW return {{Layout::Undef()}, {Layout::Undef()}}; } else { - for (size_t i = 0; i < new_layout_name.size(); ++i) { - auto index = layout.IndexOf(new_layout[i]); - if (index == -1) { - return {{Layout::Undef()}, {Layout::Undef()}}; + if (params->axes) { + auto axes = params->axes.value(); + new_begin.resize(axes.size()); + new_end.resize(axes.size()); + new_strides.resize(axes.size()); + Array<Integer> new_axes; + + for (size_t i = 0; i < axes.size(); ++i) { + auto old_idx = axes[i]; + auto new_idx = new_layout.IndexOf(layout[old_idx]); + new_begin.Set(new_idx, begin[i]); + new_end.Set(new_idx, end[i]); + new_strides.Set(new_idx, strides[i]); + new_axes.push_back(new_idx); } + params->axes = new_axes; - size_t new_index = static_cast<size_t>(index); - int64_t bg, ed, st; - if (strides.defined() && new_index < strides.size() && strides[new_index].defined()) { - st = strides[new_index]->value; - } else { - st = 1; - } - if 
(new_index < begin.size() && begin[new_index].defined()) { - bg = begin[new_index]->value; - } else { - bg = 0; - } - if (new_index < end.size() && end[new_index].defined()) { - ed = end[new_index]->value; - } else { - ed = shape[new_index].as<IntImmNode>()->value; - } + } else { + for (size_t i = 0; i < new_layout_name.size(); ++i) { + auto index = layout.IndexOf(new_layout[i]); + if (index == -1) { + return {{Layout::Undef()}, {Layout::Undef()}}; + } - new_begin.push_back(IntImm(begin[0]->dtype, bg)); - new_end.push_back(IntImm(end[0]->dtype, ed)); - new_strides.push_back(IntImm(strides[0]->dtype, st)); + size_t new_index = static_cast<size_t>(index); + int64_t bg, ed, st; + if (strides.defined() && new_index < strides.size() && strides[new_index].defined()) { + st = strides[new_index]->value; + } else { + st = 1; + } + if (new_index < begin.size() && begin[new_index].defined()) { + bg = begin[new_index]->value; + } else { + bg = 0; + } + if (new_index < end.size() && end[new_index].defined()) { + ed = end[new_index]->value; + } else { + ed = shape[new_index].as<IntImmNode>()->value; + } + + new_begin.push_back(IntImm(begin[0]->dtype, bg)); + new_end.push_back(IntImm(end[0]->dtype, ed)); + new_strides.push_back(IntImm(strides[0]->dtype, st)); + } } + params->begin = new_begin; params->end = new_end; params->strides = new_strides; diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index dd2dc979a731..5f3d754284b5 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -1235,6 +1235,49 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_conv_strided_slice_axes_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 28, 28, 32)) + weight = relay.var("weight", shape=(3, 3, 32, 32)) + y = relay.nn.conv2d( + x, + weight, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + 
y = relay.strided_slice(y, begin=[0, 16], end=[1, 33], strides=[1, 1], axes=[0, 3]) + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(): + x = relay.var("x", shape=(1, 28, 28, 32)) + weight = relay.var("weight", shape=(3, 3, 32, 32)) + weight = relay.layout_transform(weight, "HWIO", "OIHW") + x = relay.layout_transform(x, "NHWC", "NCHW") + y = relay.nn.conv2d( + x, + weight, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + ) + y = relay.strided_slice(y, begin=[0, 16], end=[1, 33], strides=[1, 1], axes=[0, 1]) + + y = relay.layout_transform(y, "NCHW", "NHWC") + y = relay.Function(analysis.free_vars(y), y) + return y + + a = run_opt_pass(before(), transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + def test_conv_roi_pool_convert_layout(): def before(): x = relay.var("x", shape=(1, 64, 56, 56)) @@ -1784,3 +1827,4 @@ def expected(): test_convert_with_config() test_conv_squeeze_convert_layout() test_conv_reduce_convert_layout() + test_conv_strided_slice_axes_convert_layout()