Skip to content

Commit

Permalink
Register layout conversion function to more reduce ops
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Sep 21, 2021
1 parent 18a36a7 commit 84353c5
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 38 deletions.
19 changes: 14 additions & 5 deletions src/relay/op/tensor/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,14 @@ Array<Integer> GetExcludeAxes(size_t indim, const Array<Integer>& inaxis) {
}

// Return the modified layout for AlterOpLayout pass.
template <typename T>
InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
const auto* attrs_ptr = attrs.as<ReduceAttrs>();
const auto* attrs_ptr = attrs.as<T>();
ICHECK(attrs_ptr);
ObjectPtr<ReduceAttrs> params = make_object<ReduceAttrs>(*attrs_ptr);
ObjectPtr<T> params = make_object<T>(*attrs_ptr);

// Get the reduce axes.
Array<Array<IndexExpr>> old_in_shapes;
Expand Down Expand Up @@ -389,6 +390,7 @@ values over a given axis.
.set_support_level(4)
.add_type_rel("ArgReduce", GenericReduceRel<ArgReduceAttrs>)
.set_attr<FTVMCompute>("FTVMCompute", ArgMaxCompute)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ArgReduceAttrs>)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);

Array<te::Tensor> ArgMinCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
Expand All @@ -405,6 +407,7 @@ values over a given axis.
.set_support_level(4)
.add_type_rel("ArgReduce", GenericReduceRel<ArgReduceAttrs>)
.set_attr<FTVMCompute>("FTVMCompute", ArgMinCompute)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ArgReduceAttrs>)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);

Array<te::Tensor> SumCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
Expand Down Expand Up @@ -433,7 +436,7 @@ Example::
.set_attrs_type<ReduceAttrs>()
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
.set_attr<FTVMCompute>("FTVMCompute", SumCompute)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);

Expand Down Expand Up @@ -468,6 +471,7 @@ Example::
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel)
.set_attr<FTVMCompute>("FTVMCompute", AllCompute)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);

Array<te::Tensor> AnyCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
Expand Down Expand Up @@ -516,6 +520,7 @@ RELAY_REGISTER_REDUCE_OP("max")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel)
.set_attr<FTVMCompute>("FTVMCompute", MaxCompute)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);

Array<te::Tensor> MinCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
Expand All @@ -531,6 +536,7 @@ RELAY_REGISTER_REDUCE_OP("min")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel)
.set_attr<FTVMCompute>("FTVMCompute", MinCompute)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);

Array<te::Tensor> ProdCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
Expand All @@ -551,17 +557,18 @@ Example::
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]
mean(data, axis=1)
prod(data, axis=1)
[35562240]
mean(data, axis=[1,2])
prod(data, axis=[1,2])
[ 36 480 2058]
)code" TVM_ADD_FILELINE)
.set_attrs_type<ReduceAttrs>()
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel)
.set_attr<FTVMCompute>("FTVMCompute", ProdCompute)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);

Array<te::Tensor> MeanCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
Expand Down Expand Up @@ -600,6 +607,7 @@ Example::
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel)
.set_attr<FTVMCompute>("FTVMCompute", MeanCompute)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);

bool VarianceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
Expand Down Expand Up @@ -675,6 +683,7 @@ RELAY_REGISTER_OP("variance")
.add_argument("mean", "Tensor", "The mean tensor.")
.add_type_rel("Variance", VarianceRel)
.set_attr<FTVMCompute>("FTVMCompute", VarianceCompute)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
.set_attr<TOpPattern>("TOpPattern", kCommReduce);

} // namespace relay
Expand Down
80 changes: 47 additions & 33 deletions tests/python/relay/test_pass_convert_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""Test alter op layout pass"""
import pytest

import tvm
from tvm import te

Expand Down Expand Up @@ -1925,37 +1927,49 @@ def infer_correct_layout_relu(attrs, new_in_layouts, old_in_layouts, old_in_type
assert test_infer_correct_layout_flag == True


def test_reduce_op_convert_layout():
for reduce_op in [relay.argmax, relay.mean, relay.max]:

def before():
x = relay.var("x", shape=(1, 64, 56, 56))
weight = relay.var("weight", shape=(64, 64, 3, 3))
y = relay.nn.conv2d(
x,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
)
y = reduce_op(y, axis=[2, 3])
y = relay.Function([x, weight], y)
return y

def expected():
x = relay.var("x", shape=(1, 64, 56, 56))
weight = relay.var("weight", shape=(64, 64, 3, 3))
x = relay.layout_transform(x, "NCHW", "NHWC")
weight = relay.layout_transform(weight, "OIHW", "HWIO")
y = relay.nn.conv2d(
x,
weight,
channels=64,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NHWC",
kernel_layout="HWIO",
)
y = reduce_op(y, axis=[1, 2])
y = relay.Function(relay.analysis.free_vars(y), y)
return y

a = before()
a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]}))
b = run_opt_pass(expected(), transform.InferType())

assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


if __name__ == "__main__":
test_qnn_binary_no_convert_layout()
test_no_convert_layout()
test_conv_convert_layout()
test_conv_nhwc_convert_layout()
test_conv_bias_pool_convert_layout()
test_conv_concat_convert_layout()
test_dual_path_convert_layout()
test_bn_convert_layout()
test_slice_like_convert_layout()
test_transpose_convert_layout()
test_resnet_convert_layout()
test_scalar_convert_layout()
test_conv_bn_convert_layout()
test_qnn_conv_requantize_convert_layout()
test_qnn_conv_concat_convert_layout()
test_qnn_conv_add_convert_layout()
test_qnn_conv_nhwc_convert_layout()
test_conv_convert_kernel_layout()
test_conv_transpose_convert_layout()
test_conv_roi_align_convert_layout()
test_conv_roi_pool_convert_layout()
test_conv_strided_slice_convert_layout()
test_deformable_conv_bias_pool_convert_layout()
test_default_keyword()
test_different_ops_convert_layout()
test_no_desired_layout()
test_convert_with_config()
test_conv_squeeze_convert_layout()
test_conv_reduce_convert_layout()
test_conv_strided_slice_axes_convert_layout()
test_image_resize_convert_layout()
test_conv_image_resize_convert_layout()
test_infer_correct_layout()
pytest.main([__file__])

0 comments on commit 84353c5

Please sign in to comment.