Skip to content

Commit

Permalink
layout transform support complete
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 27, 2021
1 parent 5782b70 commit b8ede24
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 37 deletions.
105 changes: 68 additions & 37 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2457,7 +2457,6 @@ bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
Array<Integer> axes;
if (param->axes) {
axes = param->axes.value();
LOG(INFO) << axes.size() << ", " << begin.size() << ", " << end.size() << ", " << strides.size();
ICHECK(axes.size() == begin.size() && axes.size() == end.size() &&
axes.size() == strides.size())
<< "axes, begin, end, and strides must have the same length";
Expand Down Expand Up @@ -2540,17 +2539,14 @@ Array<Array<Layout>> StridedSliceInferCorrectLayout(const Attrs& attrs,
} else {
if (params->axes) {
auto axes = params->axes.value();
new_begin.resize(axes.size());
new_end.resize(axes.size());
new_strides.resize(axes.size());
Array<Integer> new_axes;

for (size_t i = 0; i < axes.size(); ++i) {
auto old_idx = axes[i];
auto new_idx = new_layout.IndexOf(layout[old_idx]);
new_begin.Set(new_idx, begin[i]);
new_end.Set(new_idx, end[i]);
new_strides.Set(new_idx, strides[i]);
new_begin.push_back(begin[i]);
new_end.push_back(end[i]);
new_strides.push_back(strides[i]);
new_axes.push_back(new_idx);
}
params->axes = new_axes;
Expand Down Expand Up @@ -2592,44 +2588,79 @@ Array<Array<Layout>> StridedSliceInferCorrectLayout(const Attrs& attrs,
layout = new_layout;
}
} else {
for (size_t i = 0; i < begin.size(); i++) {
const LayoutAxis& axis = layout[i];
if (!axis.IsPrimal()) {
// original layout that contains splitted axes is not supported
return {{Layout::Undef()}, {Layout::Undef()}};
}
auto factor = new_layout.FactorOf(axis);
if (factor == -1) {
new_begin.push_back(IntImm(begin[i]->dtype, begin[i]));
new_end.push_back(IntImm(end[i]->dtype, end[i]));
} else {
if (strides.defined() && i < strides.size()) {
auto stride = strides[i];
// arbitrary stride is not supported
if (stride.defined() && stride->value != 1) {
if (params->axes) {
auto axes = params->axes.value();
Array<Integer> new_axes;

for (size_t i = 0; i < axes.size(); ++i) {
auto old_idx = axes[i];
auto new_idx = new_layout.IndexOf(layout[old_idx]);
new_axes.push_back(new_idx);

const LayoutAxis& axis = layout[old_idx];
if (!axis.IsPrimal()) {
// original layout that contains splitted axes is not supported
return {{Layout::Undef()}, {Layout::Undef()}};
}

auto factor = new_layout.FactorOf(axis);

if (factor == -1) {
new_begin.push_back(begin[i]);
new_end.push_back(end[i]);
} else {
int64_t bg = begin[i];
int64_t ed = end[i];
if (bg % factor || ed % factor) {
// transform to original layout
return {{Layout::Undef()}, {Layout::Undef()}};
}
new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor)));
new_end.push_back(IntImm(end[0]->dtype, (ed / factor)));
}
int64_t bg = begin[i].defined() ? begin[i]->value : 0;
int64_t ed;
if (!end[i].defined()) {
ed = shape[i].as<IntImmNode>()->value;
} else if (params->slice_mode == "size") {
if (end[i]->value < 0) {
}
params->axes = new_axes;

} else {
for (size_t i = 0; i < begin.size(); i++) {
const LayoutAxis& axis = layout[i];
if (!axis.IsPrimal()) {
// original layout that contains splitted axes is not supported
return {{Layout::Undef()}, {Layout::Undef()}};
}
auto factor = new_layout.FactorOf(axis);
if (factor == -1) {
new_begin.push_back(IntImm(begin[i]->dtype, begin[i]));
new_end.push_back(IntImm(end[i]->dtype, end[i]));
} else {
if (strides.defined() && i < strides.size()) {
auto stride = strides[i];
// arbitrary stride is not supported
if (stride.defined() && stride->value != 1) {
return {{Layout::Undef()}, {Layout::Undef()}};
}
}
int64_t bg = begin[i].defined() ? begin[i]->value : 0;
int64_t ed;
if (!end[i].defined()) {
ed = shape[i].as<IntImmNode>()->value;
} else if (params->slice_mode == "size") {
if (end[i]->value < 0) {
ed = shape[i].as<IntImmNode>()->value;
} else {
ed = bg + end[i]->value;
}
} else {
ed = bg + end[i]->value;
ed = end[i]->value;
}
} else {
ed = end[i]->value;
}

if (bg % factor || ed % factor) {
// transform to original layout
return {{Layout::Undef()}, {Layout::Undef()}};
if (bg % factor || ed % factor) {
// transform to original layout
return {{Layout::Undef()}, {Layout::Undef()}};
}
new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor)));
new_end.push_back(IntImm(end[0]->dtype, (ed / factor)));
}
new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor)));
new_end.push_back(IntImm(end[0]->dtype, (ed / factor)));
}
}

Expand Down
56 changes: 56 additions & 0 deletions tests/python/relay/test_pass_alter_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,61 @@ def expected():
)


@tvm.testing.uses_gpu
def test_alter_layout_strided_slice_axes_nhwc():
    """Test rewriting strided_slice with explicit axes during alter_op_layout."""

    def alter_conv2d(attrs, inputs, tinfos, out_type):
        # Rewrite conv2d to use the packed NHWC4c data layout, forcing
        # AlterOpLayout to re-infer layouts for the downstream strided_slice.
        data, weight = inputs
        updated_attrs = dict(attrs)
        updated_attrs["data_layout"] = "NHWC4c"
        return relay.nn.conv2d(data, weight, **updated_attrs)

    def before():
        # conv2d (NHWC) followed by a strided_slice on the batch and
        # channel axes only, expressed via the `axes` argument.
        data = relay.var("x", shape=(1, 28, 28, 32))
        kernel = relay.var("weight", shape=(3, 3, 32, 32))
        conv = relay.nn.conv2d(
            data,
            kernel,
            channels=32,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NHWC",
            kernel_layout="HWIO",
        )
        sliced = relay.strided_slice(conv, begin=[0, 16], end=[1, 32], strides=[1, 1], axes=[0, 3])
        return relay.Function(analysis.free_vars(sliced), sliced)

    def expected():
        # Same graph after AlterOpLayout: layout_transform wraps the conv,
        # and the channel slice bounds are divided by the packing factor 4.
        data = relay.var("x", shape=(1, 28, 28, 32))
        kernel = relay.var("weight", shape=(3, 3, 32, 32))
        packed = relay.layout_transform(data, "NHWC", "NHWC4c")
        conv = relay.op.nn.conv2d(
            packed,
            kernel,
            channels=32,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NHWC4c",
            kernel_layout="HWIO",
        )
        sliced = relay.strided_slice(conv, begin=[0, 4], end=[1, 8], strides=[1, 1], axes=[0, 3])
        restored = relay.layout_transform(sliced, "NHWC4c", "NHWC")
        return relay.Function(analysis.free_vars(restored), restored)

    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
        actual = run_opt_pass(before(), transform.AlterOpLayout())
        reference = run_opt_pass(expected(), transform.InferType())

    mod_actual = tvm.IRModule()
    mod_reference = tvm.IRModule()
    mod_actual["main"] = actual
    mod_reference["main"] = reference
    assert tvm.ir.structural_equal(mod_actual, mod_reference)


def test_alter_layout_depthwise_conv2d():
"""Test depthwise_conv2d operator"""

Expand Down Expand Up @@ -1298,3 +1353,4 @@ def expected():
test_alter_layout_nhwc_int8_aarch64()
test_alter_op_with_global_var()
test_alter_op_dense()
test_alter_layout_strided_slice_axes_nhwc()

0 comments on commit b8ede24

Please sign in to comment.