[ConvertLayout] Squeeze and reduce ops #7835

Merged · 2 commits · Apr 15, 2021
44 changes: 31 additions & 13 deletions src/relay/op/tensor/reduce.cc
@@ -131,13 +131,11 @@ Array<Array<Layout>> ReduceInferCorrectLayout(const Attrs& attrs,
uint32_t indim = old_in_shapes[0].size();
auto r_axes = GetReduceAxes(indim, params->axis, params->exclude);

Layout ret = Layout::Undef();
if (new_in_layouts.defined() && r_axes.size()) {
// Adapt to new layout. The axis has to change. Record original reduce axes. Convert to the
// modified layout axes.
ICHECK_EQ(new_in_layouts.size(), 1);
ICHECK_EQ(old_in_layouts.size(), 1);
Layout inferred_in = Layout::Undef();
Layout inferred_out = Layout::Undef();

// Infer [in_layout, out_layout, new_r_axes] from old_in_layout or new_in_layout
auto infer = [&](const Layout& layout) {
// 1) Collect the original axes
std::unordered_set<std::string> old_r_dims;
for (auto r_axis : r_axes) {
@@ -146,31 +144,51 @@ Array<Array<Layout>> ReduceInferCorrectLayout(const Attrs& attrs,

// 2) Collect the new axes by walking new_layout.
tvm::Array<tvm::Integer> new_r_axes;
std::string new_layout_string = "";
std::string inferred_in_string = "";
std::string inferred_out_string = "";
int axis_index = 0;
for (auto iter_var : new_in_layouts[0]->axes) {
for (auto iter_var : layout->axes) {
const auto& layout_axis = LayoutAxis::Get(iter_var);
const std::string& layout_dim = layout_axis.name();
if (old_r_dims.count(layout_dim)) {
new_r_axes.push_back(tvm::Integer(axis_index));
}
// Collect only the primal axis.
if (layout_axis.IsPrimal()) {
new_layout_string += layout_dim;
if (!old_r_dims.count(layout_dim) || params->keepdims) {
inferred_out_string += layout_dim;
}
inferred_in_string += layout_dim;
axis_index++;
}
}

// 3) Set the new axis and layout.
ret = Layout(new_layout_string);
return std::make_tuple(Layout(inferred_in_string), Layout(inferred_out_string), new_r_axes);
};

std::string new_layout_string;
Array<Integer> new_r_axes;

if (new_in_layouts.defined() && r_axes.size()) {
// Adapt to new layout. The axis has to change. Record original reduce axes. Convert to the
// modified layout axes.
ICHECK_EQ(new_in_layouts.size(), 1);
ICHECK_EQ(old_in_layouts.size(), 1);

// Get inferred_in and inferred_out from new_in_layout.
std::tie(inferred_in, inferred_out, new_r_axes) = infer(new_in_layouts[0]);
params->axis = new_r_axes;
} else if (old_in_layouts.defined()) {
// If the new layout is undefined, set the old layout as the inferred layout.
ICHECK_EQ(old_in_layouts.size(), 1);
ret = old_in_layouts[0];

// If the new layout is undefined, get inferred_in and inferred_out from old_in_layout.
if (old_in_layouts[0].defined()) {
std::tie(inferred_in, inferred_out, std::ignore) = infer(old_in_layouts[0]);
}
}

return Array<Array<Layout>>{{ret}, {ret}};
return Array<Array<Layout>>{{inferred_in}, {inferred_out}};
}

template <typename F>
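
With this change, ReduceInferCorrectLayout reports separate input and output layouts: the reduce axes are remapped into whichever layout is available (the producer's new layout if defined, otherwise the old one), and reduced primal axes are dropped from the output layout unless keepdims is set. A minimal Python sketch of the intended effect, mirroring _test_conv_reduce_convert_layout1 added below (the printed IR is illustrative only, not asserted here):

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 1, 1, 2048))
weight = relay.var("weight", shape=(1, 1, 2048, 1000))
y = relay.nn.conv2d(x, weight, channels=1000, kernel_size=(1, 1),
                    data_layout="NHWC", kernel_layout="HWIO")
y = relay.nn.relu(y)
y = relay.sum(y, axis=(1, 2))  # reduce over H and W of the NHWC result
f = relay.Function(relay.analysis.free_vars(y), y)

mod = tvm.IRModule.from_expr(f)
seq = tvm.transform.Sequential([
    relay.transform.InferType(),
    relay.transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}),
])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
# Expected: the sum now reduces axis=(2, 3) of the NCHW tensor, and no
# layout_transform is needed after it because H and W drop out of the
# inferred output layout.
print(mod)
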
66 changes: 65 additions & 1 deletion src/relay/op/tensor/transform.cc
@@ -2159,6 +2159,69 @@ Array<te::Tensor> SqueezeCompute(const Attrs& attrs, const Array<te::Tensor>& in
return {topi::squeeze(inputs[0], param->axis)};
}

Array<Array<Layout>> SqueezeInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
// NOTE: Discard "const" qualifier here.
SqueezeAttrs* params = const_cast<SqueezeAttrs*>(attrs.as<SqueezeAttrs>());

Layout inferred_input = new_in_layouts.defined() ? new_in_layouts[0] : old_in_layouts[0];
Layout inferred_output = inferred_input;

ICHECK(old_in_types[0].as<TensorTypeNode>());
const auto& shape = old_in_types[0].as<TensorTypeNode>()->shape;

// axis to squeeze
Array<Integer> axis;
if (params->axis.defined()) {
axis = params->axis;
} else {
// if axes is None, squeeze all axes of dimension 1
for (size_t i = 0; i < shape.size(); i++) {
if (topi::detail::GetConstInt(shape[i]) == 1) {
axis.push_back(i);
}
}
}

// If new_in_layouts are defined, this code tries to modify the layout
if (new_in_layouts.defined() && old_in_layouts.defined()) {
Array<Integer> new_axis;
for (const auto& e : axis) {
const auto& dim = old_in_layouts[0][e];
new_axis.push_back((new_in_layouts[0]).IndexOf(dim));
}
params->axis = new_axis;
axis = new_axis;
}

// Infer output layout
Array<tir::IterVar> kept_axes;
for (size_t i = 0; i < inferred_input.ndim(); i++) {
bool is_dim_kept = true;

// Check whether the dim should be kept
for (const auto& e : axis) {
int64_t axis_val = e->value;
if (axis_val < 0) {
axis_val += inferred_input.ndim();
}
if (static_cast<int64_t>(i) == axis_val) {
is_dim_kept = false;
break;
}
}

if (is_dim_kept) {
kept_axes.push_back(inferred_input->axes[i]);
}
}
inferred_output = Layout(kept_axes);

return Array<Array<Layout>>{{inferred_input}, {inferred_output}};
}

RELAY_REGISTER_OP("squeeze")
.describe(R"code(Squeeze the input tensor at the dimensions given by axes

@@ -2171,7 +2234,8 @@ RELAY_REGISTER_OP("squeeze")
.set_support_level(3)
.add_type_rel("Squeeze", SqueezeRel)
.set_attr<FTVMCompute>("FTVMCompute", SqueezeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", SqueezeInferCorrectLayout);

// CollapseSumLike: <A, B> -> B where BroadCast(A, B) = A
bool CollapseSumLikeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
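
The new FInferCorrectLayout hook for squeeze remaps the squeeze axes into the producer's new layout and builds the output layout by dropping the squeezed axes (so an NCHW input squeezed on axis 2 yields NCW). A minimal Python sketch of the intended effect, mirroring _test_conv_squeeze_convert_layout1 added below (illustrative only):

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 1, 1, 2048))
weight = relay.var("weight", shape=(1, 1, 2048, 1000))
y = relay.nn.conv2d(x, weight, channels=1000, kernel_size=(1, 1),
                    data_layout="NHWC", kernel_layout="HWIO")
y = relay.nn.relu(y)
y = relay.squeeze(y, axis=[-3])  # squeeze H of the NHWC result
f = relay.Function(relay.analysis.free_vars(y), y)

mod = tvm.IRModule.from_expr(f)
seq = tvm.transform.Sequential([
    relay.transform.InferType(),
    relay.transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}),
])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
# Expected: the squeeze axis is remapped to [2] for the NCHW input, the
# result carries an NCW layout, and a single layout_transform NCW -> NWC
# restores the original ordering at the boundary.
print(mod)
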
14 changes: 11 additions & 3 deletions src/relay/transforms/convert_layout.cc
@@ -91,9 +91,17 @@ class ConvertTransformMemorizer : public TransformMemorizer {
auto desired_layouts = operator->()->desired_layouts_;
if (desired_layouts.find(op->name) != desired_layouts.end()) {
tvm::Array<tvm::te::Tensor> tinfos;
for (auto expr : ref_call->args) {
auto ttype = expr->type_as<TensorTypeNode>();
tinfos.push_back(tvm::te::placeholder(ttype->shape, ttype->dtype));
for (auto& expr : ref_call->args) {
if (expr->checked_type()->IsInstance<TupleTypeNode>()) {
auto tuple_ttype_node = expr->type_as<TupleTypeNode>();
for (auto& ttype : tuple_ttype_node->fields) {
auto ttype_node = ttype.as<TensorTypeNode>();
tinfos.push_back(tvm::te::placeholder(ttype_node->shape, ttype_node->dtype));
}
} else {
auto ttype = expr->type_as<TensorTypeNode>();
tinfos.push_back(tvm::te::placeholder(ttype->shape, ttype->dtype));
}
}

Array<String> op_desired_layouts = desired_layouts.at(op->name);
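
The memorizer previously assumed every call argument had TensorType, so an op whose argument is tuple-typed (for example the input of concatenate) could not have its placeholders built. A Python analogue of the new flattening logic is sketched below purely for illustration; flatten_arg_placeholders is a hypothetical helper, not part of the TVM API:

import tvm
from tvm import relay, te

def flatten_arg_placeholders(args):
    # Sketch of the C++ change above: one te.placeholder per tensor,
    # flattening tuple-typed arguments into their fields. Assumes the
    # arguments already carry checked types (run InferType first).
    tinfos = []
    for expr in args:
        ty = expr.checked_type
        if isinstance(ty, relay.TupleType):
            for field in ty.fields:
                tinfos.append(te.placeholder(field.shape, dtype=field.dtype))
        else:
            tinfos.append(te.placeholder(ty.shape, dtype=ty.dtype))
    return tinfos
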
187 changes: 187 additions & 0 deletions tests/python/relay/test_pass_convert_op_layout.py
@@ -1556,6 +1556,191 @@ def expected():
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


def test_conv_squeeze_convert_layout():
def _test_conv_squeeze_convert_layout1():
# specified axis is squeezed
def before():
x = relay.var("x", shape=(1, 1, 1, 2048))
weight = relay.var("weight", shape=(1, 1, 2048, 1000))
y = relay.nn.conv2d(
x,
weight,
channels=1000,
kernel_size=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
)
y = relay.nn.relu(y)
y = relay.squeeze(y, axis=[-3])
return relay.Function(analysis.free_vars(y), y)

def expected():
x = relay.var("x", shape=(1, 1, 1, 2048))
weight = relay.var("weight", shape=(1, 1, 2048, 1000))
weight = relay.layout_transform(weight, "HWIO", "OIHW")
x = relay.layout_transform(x, "NHWC", "NCHW")
y = relay.nn.conv2d(x, weight, channels=1000, kernel_size=(1, 1))
y = relay.nn.relu(y)
y = relay.squeeze(y, axis=[2])
y = relay.layout_transform(y, "NCW", "NWC")
return relay.Function(analysis.free_vars(y), y)

a = before()
a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
b = run_opt_pass(expected(), transform.InferType())

assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)

def _test_conv_squeeze_convert_layout2():
# all axes of dimension 1 are squeezed
def before():
x = relay.var("x", shape=(1, 1, 1, 2048))
weight = relay.var("weight", shape=(1, 1, 2048, 1000))
y = relay.nn.conv2d(
x,
weight,
channels=1000,
kernel_size=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
)
y = relay.nn.relu(y)
y = relay.squeeze(y)
return relay.Function(analysis.free_vars(y), y)

def expected():
x = relay.var("x", shape=(1, 1, 1, 2048))
weight = relay.var("weight", shape=(1, 1, 2048, 1000))
weight = relay.layout_transform(weight, "HWIO", "OIHW")
x = relay.layout_transform(x, "NHWC", "NCHW")
y = relay.nn.conv2d(x, weight, channels=1000, kernel_size=(1, 1))
y = relay.nn.relu(y)
y = relay.squeeze(y, [0, 2, 3])
return relay.Function(analysis.free_vars(y), y)

a = before()
a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
b = run_opt_pass(expected(), transform.InferType())

assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)

def _test_conv_squeeze_convert_layout3():
# squeeze axis is empty
def before():
x = relay.var("x", shape=(1, 1, 1, 2048))
weight = relay.var("weight", shape=(1, 1, 2048, 1000))
y = relay.nn.conv2d(
x,
weight,
channels=1000,
kernel_size=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
)
y = relay.nn.relu(y)
y = relay.squeeze(y, axis=[])
return relay.Function(analysis.free_vars(y), y)

def expected():
x = relay.var("x", shape=(1, 1, 1, 2048))
weight = relay.var("weight", shape=(1, 1, 2048, 1000))
weight = relay.layout_transform(weight, "HWIO", "OIHW")
x = relay.layout_transform(x, "NHWC", "NCHW")
y = relay.nn.conv2d(x, weight, channels=1000, kernel_size=(1, 1))
y = relay.nn.relu(y)
y = relay.squeeze(y, axis=[])
y = relay.layout_transform(y, "NCHW", "NHWC")
return relay.Function(analysis.free_vars(y), y)

a = before()
a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
b = run_opt_pass(expected(), transform.InferType())

assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)

_test_conv_squeeze_convert_layout1()
_test_conv_squeeze_convert_layout2()
_test_conv_squeeze_convert_layout3()


def test_conv_reduce_convert_layout():
def _test_conv_reduce_convert_layout1():
def before():
x = relay.var("x", shape=(1, 1, 1, 2048))
weight = relay.var("weight", shape=(1, 1, 2048, 1000))
y = relay.nn.conv2d(
x,
weight,
channels=1000,
kernel_size=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
)
y = relay.nn.relu(y)
y = relay.sum(y, axis=(1, 2))
y = relay.sum(y, axis=(1,))
y = relay.sum(y)
y = relay.sum(y)
return relay.Function(analysis.free_vars(y), y)

def expected():
x = relay.var("x", shape=(1, 1, 1, 2048))
weight = relay.var("weight", shape=(1, 1, 2048, 1000))
weight = relay.layout_transform(weight, "HWIO", "OIHW")
x = relay.layout_transform(x, "NHWC", "NCHW")
y = relay.nn.conv2d(x, weight, channels=1000, kernel_size=(1, 1))
y = relay.nn.relu(y)
y = relay.sum(y, axis=(2, 3))
y = relay.sum(y, axis=(1,))
y = relay.sum(y)
y = relay.sum(y)
return relay.Function(analysis.free_vars(y), y)

a = before()
a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
b = run_opt_pass(expected(), transform.InferType())

assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)

def _test_conv_reduce_convert_layout2():
def before():
x = relay.var("x", shape=(1, 38, 38, 512))
weight = relay.var("weight", shape=(3, 3, 512, 512))
y = relay.nn.conv2d(
x,
weight,
channels=512,
kernel_size=(3, 3),
data_layout="NHWC",
kernel_layout="HWIO",
)
y = relay.nn.relu(y)
y = relay.multiply(y, y)
y = relay.sum(y, axis=(3,), keepdims=True)
return relay.Function(analysis.free_vars(y), y)

def expected():
x = relay.var("x", shape=(1, 38, 38, 512))
weight = relay.var("weight", shape=(3, 3, 512, 512))
weight = relay.layout_transform(weight, "HWIO", "OIHW")
x = relay.layout_transform(x, "NHWC", "NCHW")
y = relay.nn.conv2d(x, weight, channels=512, kernel_size=(3, 3))
y = relay.nn.relu(y)
y = relay.multiply(y, y)
y = relay.sum(y, axis=(1,), keepdims=True)
y = relay.layout_transform(y, "NCHW", "NHWC")
return relay.Function(analysis.free_vars(y), y)

a = before()
a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
b = run_opt_pass(expected(), transform.InferType())

assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)

_test_conv_reduce_convert_layout1()
_test_conv_reduce_convert_layout2()


if __name__ == "__main__":
test_qnn_binary_no_convert_layout()
test_no_convert_layout()
@@ -1584,3 +1769,5 @@ def expected():
test_different_ops_convert_layout()
test_no_desired_layout()
test_convert_with_config()
test_conv_squeeze_convert_layout()
test_conv_reduce_convert_layout()