Skip to content

Commit

Permalink
Support layout transform for strided_slice with explicit axes (part 1)
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 27, 2021
1 parent e94aa6b commit 5782b70
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 24 deletions.
69 changes: 45 additions & 24 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2457,6 +2457,7 @@ bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
Array<Integer> axes;
if (param->axes) {
  axes = param->axes.value();
  // When axes is given, begin/end/strides are interpreted per listed axis,
  // so all four arrays must have identical length.
  // NOTE(review): removed a leftover LOG(INFO) that dumped the array sizes
  // on every type inference — debug output that should not ship.
  ICHECK(axes.size() == begin.size() && axes.size() == end.size() &&
         axes.size() == strides.size())
      << "axes, begin, end, and strides must have the same length";
Expand Down Expand Up @@ -2537,34 +2538,54 @@ Array<Array<Layout>> StridedSliceInferCorrectLayout(const Attrs& attrs,
// Not support NHW4c -> NCHW
return {{Layout::Undef()}, {Layout::Undef()}};
} else {
for (size_t i = 0; i < new_layout_name.size(); ++i) {
auto index = layout.IndexOf(new_layout[i]);
if (index == -1) {
return {{Layout::Undef()}, {Layout::Undef()}};
if (params->axes) {
auto axes = params->axes.value();
new_begin.resize(axes.size());
new_end.resize(axes.size());
new_strides.resize(axes.size());
Array<Integer> new_axes;

for (size_t i = 0; i < axes.size(); ++i) {
auto old_idx = axes[i];
auto new_idx = new_layout.IndexOf(layout[old_idx]);
new_begin.Set(new_idx, begin[i]);
new_end.Set(new_idx, end[i]);
new_strides.Set(new_idx, strides[i]);
new_axes.push_back(new_idx);
}
params->axes = new_axes;

size_t new_index = static_cast<size_t>(index);
int64_t bg, ed, st;
if (strides.defined() && new_index < strides.size() && strides[new_index].defined()) {
st = strides[new_index]->value;
} else {
st = 1;
}
if (new_index < begin.size() && begin[new_index].defined()) {
bg = begin[new_index]->value;
} else {
bg = 0;
}
if (new_index < end.size() && end[new_index].defined()) {
ed = end[new_index]->value;
} else {
ed = shape[new_index].as<IntImmNode>()->value;
}
} else {
for (size_t i = 0; i < new_layout_name.size(); ++i) {
auto index = layout.IndexOf(new_layout[i]);
if (index == -1) {
return {{Layout::Undef()}, {Layout::Undef()}};
}

new_begin.push_back(IntImm(begin[0]->dtype, bg));
new_end.push_back(IntImm(end[0]->dtype, ed));
new_strides.push_back(IntImm(strides[0]->dtype, st));
size_t new_index = static_cast<size_t>(index);
int64_t bg, ed, st;
if (strides.defined() && new_index < strides.size() && strides[new_index].defined()) {
st = strides[new_index]->value;
} else {
st = 1;
}
if (new_index < begin.size() && begin[new_index].defined()) {
bg = begin[new_index]->value;
} else {
bg = 0;
}
if (new_index < end.size() && end[new_index].defined()) {
ed = end[new_index]->value;
} else {
ed = shape[new_index].as<IntImmNode>()->value;
}

new_begin.push_back(IntImm(begin[0]->dtype, bg));
new_end.push_back(IntImm(end[0]->dtype, ed));
new_strides.push_back(IntImm(strides[0]->dtype, st));
}
}

params->begin = new_begin;
params->end = new_end;
params->strides = new_strides;
Expand Down
44 changes: 44 additions & 0 deletions tests/python/relay/test_pass_convert_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,6 +1235,49 @@ def expected():
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


def test_conv_strided_slice_axes_convert_layout():
    """ConvertLayout must remap strided_slice's `axes` when conv2d goes NHWC -> NCHW."""

    def before():
        # NHWC conv2d followed by a strided_slice on the batch (0) and channel (3) axes.
        data = relay.var("x", shape=(1, 28, 28, 32))
        kernel = relay.var("weight", shape=(3, 3, 32, 32))
        out = relay.nn.conv2d(
            data,
            kernel,
            channels=32,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NHWC",
            kernel_layout="HWIO",
        )
        out = relay.strided_slice(out, begin=[0, 16], end=[1, 33], strides=[1, 1], axes=[0, 3])
        return relay.Function(analysis.free_vars(out), out)

    def expected():
        # After conversion the channel axis moves from index 3 (NHWC) to index 1 (NCHW),
        # so the slice axes become [0, 1]; layout_transforms bracket the converted conv.
        data = relay.var("x", shape=(1, 28, 28, 32))
        kernel = relay.var("weight", shape=(3, 3, 32, 32))
        kernel = relay.layout_transform(kernel, "HWIO", "OIHW")
        data = relay.layout_transform(data, "NHWC", "NCHW")
        out = relay.nn.conv2d(
            data,
            kernel,
            channels=32,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
        )
        out = relay.strided_slice(out, begin=[0, 16], end=[1, 33], strides=[1, 1], axes=[0, 1])
        out = relay.layout_transform(out, "NCHW", "NHWC")
        return relay.Function(analysis.free_vars(out), out)

    a = run_opt_pass(before(), transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
    b = run_opt_pass(expected(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


def test_conv_roi_pool_convert_layout():
def before():
x = relay.var("x", shape=(1, 64, 56, 56))
Expand Down Expand Up @@ -1784,3 +1827,4 @@ def expected():
test_convert_with_config()
test_conv_squeeze_convert_layout()
test_conv_reduce_convert_layout()
test_conv_strided_slice_axes_convert_layout()

0 comments on commit 5782b70

Please sign in to comment.