From b8ede24ff987eb152bde7cc15afce004a88aeb5f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 27 May 2021 10:21:06 +0900 Subject: [PATCH] layout transform support complete --- src/relay/op/tensor/transform.cc | 105 ++++++++++++------ .../python/relay/test_pass_alter_op_layout.py | 56 ++++++++++ 2 files changed, 124 insertions(+), 37 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index d156b2a0b7fb..34f375f41870 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2457,7 +2457,6 @@ bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attr Array axes; if (param->axes) { axes = param->axes.value(); - LOG(INFO) << axes.size() << ", " << begin.size() << ", " << end.size() << ", " << strides.size(); ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size()) << "axes, begin, end, and strides must have the same length"; @@ -2540,17 +2539,14 @@ Array> StridedSliceInferCorrectLayout(const Attrs& attrs, } else { if (params->axes) { auto axes = params->axes.value(); - new_begin.resize(axes.size()); - new_end.resize(axes.size()); - new_strides.resize(axes.size()); Array new_axes; for (size_t i = 0; i < axes.size(); ++i) { auto old_idx = axes[i]; auto new_idx = new_layout.IndexOf(layout[old_idx]); - new_begin.Set(new_idx, begin[i]); - new_end.Set(new_idx, end[i]); - new_strides.Set(new_idx, strides[i]); + new_begin.push_back(begin[i]); + new_end.push_back(end[i]); + new_strides.push_back(strides[i]); new_axes.push_back(new_idx); } params->axes = new_axes; @@ -2592,44 +2588,79 @@ Array> StridedSliceInferCorrectLayout(const Attrs& attrs, layout = new_layout; } } else { - for (size_t i = 0; i < begin.size(); i++) { - const LayoutAxis& axis = layout[i]; - if (!axis.IsPrimal()) { - // original layout that contains splitted axes is not supported - return {{Layout::Undef()}, {Layout::Undef()}}; - } - auto factor = new_layout.FactorOf(axis); - if (factor == -1) { - new_begin.push_back(IntImm(begin[i]->dtype, begin[i])); - new_end.push_back(IntImm(end[i]->dtype, end[i])); - } else { - if (strides.defined() && i < strides.size()) { - auto stride = strides[i]; - // arbitrary stride is not supported - if (stride.defined() && stride->value != 1) { + if (params->axes) { + auto axes = params->axes.value(); + Array new_axes; + + for (size_t i = 0; i < axes.size(); ++i) { + auto old_idx = axes[i]; + auto new_idx = new_layout.IndexOf(layout[old_idx]); + new_axes.push_back(new_idx); + + const LayoutAxis& axis = layout[old_idx]; + if (!axis.IsPrimal()) { + // original layout that contains splitted axes is not supported + return {{Layout::Undef()}, {Layout::Undef()}}; + } + + auto factor = new_layout.FactorOf(axis); + + if (factor == -1) { + new_begin.push_back(begin[i]); + new_end.push_back(end[i]); + } else { + int64_t bg = begin[i]; + int64_t ed = end[i]; + if (bg % factor || ed % factor) { + // transform to original layout return {{Layout::Undef()}, {Layout::Undef()}}; } + new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor))); + new_end.push_back(IntImm(end[0]->dtype, (ed / factor))); } - int64_t bg = begin[i].defined() ? begin[i]->value : 0; - int64_t ed; - if (!end[i].defined()) { - ed = shape[i].as()->value; - } else if (params->slice_mode == "size") { - if (end[i]->value < 0) { + } + params->axes = new_axes; + + } else { + for (size_t i = 0; i < begin.size(); i++) { + const LayoutAxis& axis = layout[i]; + if (!axis.IsPrimal()) { + // original layout that contains splitted axes is not supported + return {{Layout::Undef()}, {Layout::Undef()}}; + } + auto factor = new_layout.FactorOf(axis); + if (factor == -1) { + new_begin.push_back(IntImm(begin[i]->dtype, begin[i])); + new_end.push_back(IntImm(end[i]->dtype, end[i])); + } else { + if (strides.defined() && i < strides.size()) { + auto stride = strides[i]; + // arbitrary stride is not supported + if (stride.defined() && stride->value != 1) { + return {{Layout::Undef()}, {Layout::Undef()}}; + } + } + int64_t bg = begin[i].defined() ? begin[i]->value : 0; + int64_t ed; + if (!end[i].defined()) { ed = shape[i].as()->value; + } else if (params->slice_mode == "size") { + if (end[i]->value < 0) { + ed = shape[i].as()->value; + } else { + ed = bg + end[i]->value; + } } else { - ed = bg + end[i]->value; + ed = end[i]->value; } - } else { - ed = end[i]->value; - } - if (bg % factor || ed % factor) { - // transform to original layout - return {{Layout::Undef()}, {Layout::Undef()}}; + if (bg % factor || ed % factor) { + // transform to original layout + return {{Layout::Undef()}, {Layout::Undef()}}; + } + new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor))); + new_end.push_back(IntImm(end[0]->dtype, (ed / factor))); } - new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor))); - new_end.push_back(IntImm(end[0]->dtype, (ed / factor))); } } diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 3031c55379ae..5c2793c607a9 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -770,6 +770,61 @@ def expected(): ) +@tvm.testing.uses_gpu +def test_alter_layout_strided_slice_axes_nhwc(): + """Test rewriting strided_slice with axes during alter_iop_layout""" + + def before(): + x = relay.var("x", shape=(1, 28, 28, 32)) + weight = relay.var("weight", shape=(3, 3, 32, 32)) + y = relay.nn.conv2d( + x, + weight, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + y = relay.strided_slice(y, begin=[0, 16], end=[1, 32], strides=[1, 1], axes=[0, 3]) + y = relay.Function(analysis.free_vars(y), y) + return y + + def alter_conv2d(attrs, inputs, tinfos, out_type): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs["data_layout"] = "NHWC4c" + return relay.nn.conv2d(data, weight, **new_attrs) + + def expected(): + x = relay.var("x", shape=(1, 28, 28, 32)) + weight = relay.var("weight", shape=(3, 3, 32, 32)) + x = relay.layout_transform(x, "NHWC", "NHWC4c") + y = relay.op.nn.conv2d( + x, + weight, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC4c", + kernel_layout="HWIO", + ) + y = relay.strided_slice(y, begin=[0, 4], end=[1, 8], strides=[1, 1], axes=[0, 3]) + y = relay.layout_transform(y, "NHWC4c", "NHWC") + y = relay.Function(analysis.free_vars(y), y) + return y + + with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): + a = run_opt_pass(before(), transform.AlterOpLayout()) + b = run_opt_pass(expected(), transform.InferType()) + + mod_before = tvm.IRModule() + mod_new = tvm.IRModule() + mod_before["main"] = a + mod_new["main"] = b + assert tvm.ir.structural_equal(mod_before, mod_new) + + def test_alter_layout_depthwise_conv2d(): """Test depthwise_conv2d operator""" @@ -1298,3 +1353,4 @@ def expected(): test_alter_layout_nhwc_int8_aarch64() test_alter_op_with_global_var() test_alter_op_dense() + test_alter_layout_strided_slice_axes_nhwc()