From a1d12cc2cb7d90a2844da7536e0cde9b44c7969f Mon Sep 17 00:00:00 2001 From: shoubhik Date: Mon, 11 May 2020 16:27:23 -0700 Subject: [PATCH] [ConvertLayout] Support QNN ops. #5066 https://github.com/apache/incubator-tvm/pull/5066 --- python/tvm/relay/op/nn/_nn.py | 10 +-- python/tvm/relay/qnn/op/__init__.py | 2 +- python/tvm/relay/qnn/op/layout_conversions.py | 51 ++++++++++++ src/relay/op/nn/bitserial.cc | 2 +- src/relay/op/nn/convolution.cc | 14 ---- src/relay/op/nn/convolution.h | 16 ++++ src/relay/op/nn/nn.cc | 12 ++- src/relay/op/nn/pad.cc | 2 +- src/relay/op/nn/pooling.cc | 2 +- src/relay/op/nn/upsampling.cc | 2 +- src/relay/op/tensor/reduce.cc | 7 +- src/relay/op/tensor/transform.cc | 55 ++----------- src/relay/op/tensor/transform.h | 59 ++++++++++++++ src/relay/pass/infer_layout_util.h | 17 ++-- src/relay/pass/transform_layout.h | 15 +--- src/relay/qnn/op/add.cc | 21 ++++- src/relay/qnn/op/concatenate.cc | 41 +++++++++- src/relay/qnn/op/convolution.cc | 21 ++++- src/relay/qnn/op/requantize.cc | 77 ++++++++++++++++++- 19 files changed, 324 insertions(+), 102 deletions(-) create mode 100644 python/tvm/relay/qnn/op/layout_conversions.py diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index ff8d54b43d08..6d420793a69d 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -268,8 +268,6 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout): """ from tvm import relay - data_layout = attrs['data_layout'] - kernel_layout = attrs['kernel_layout'] data, weight = inputs assert desired_layout == 'NCHW', \ "Currently only transformation to NCHW layout is supported." @@ -277,13 +275,7 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout): new_attrs = dict(attrs) new_attrs['data_layout'] = desired_layout new_attrs['kernel_layout'] = 'OIHW' - - if data_layout == 'NHWC' and kernel_layout == 'HWIO': - # Convert (NHWC, HWIO) to (NCHW, OIHW) - return relay.nn.conv2d(data, weight, **new_attrs) - if data_layout == 'NHWC' and kernel_layout == 'HWOI': - # Convert (NHWC, HWOI) to (NCHW, OIHW). Depthwise conv2d. - return relay.nn.conv2d(data, weight, **new_attrs) + return relay.nn.conv2d(data, weight, **new_attrs) return None reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/python/tvm/relay/qnn/op/__init__.py b/python/tvm/relay/qnn/op/__init__.py index 042dcb9d1893..6d66e12eeafc 100644 --- a/python/tvm/relay/qnn/op/__init__.py +++ b/python/tvm/relay/qnn/op/__init__.py @@ -19,4 +19,4 @@ from __future__ import absolute_import as _abs from .qnn import * from .op import register_qnn_legalize -from . import legalizations +from . import legalizations, layout_conversions diff --git a/python/tvm/relay/qnn/op/layout_conversions.py b/python/tvm/relay/qnn/op/layout_conversions.py new file mode 100644 index 000000000000..64912da0c61c --- /dev/null +++ b/python/tvm/relay/qnn/op/layout_conversions.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Convert layout related registration""" +from __future__ import absolute_import + +from tvm.relay.op import op as reg + + +@reg.register_convert_op_layout("qnn.conv2d") +def convert_qnn_conv2d(attrs, inputs, tinfos, desired_layout): + """Convert Layout pass registration for QNN conv2d op. + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + tinfos : list of types + List of input and output types + desired_layout : str + The desired layout + Returns + ------- + result : tvm.relay.Expr + The transformed expr + """ + # pylint: disable=import-outside-toplevel + from tvm import relay + assert desired_layout == 'NCHW', \ + "Currently only transformation to NCHW layout is supported." + if desired_layout == 'NCHW': + new_attrs = dict(attrs) + new_attrs['data_layout'] = desired_layout + new_attrs['kernel_layout'] = 'OIHW' + return relay.qnn.op.conv2d(*inputs, **new_attrs) + return None \ No newline at end of file diff --git a/src/relay/op/nn/bitserial.cc b/src/relay/op/nn/bitserial.cc index d651baeccb4c..aed3c136c184 100644 --- a/src/relay/op/nn/bitserial.cc +++ b/src/relay/op/nn/bitserial.cc @@ -38,7 +38,7 @@ template Array> BinaryConv2DInferCorrectLayout(const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, - const Array>& old_in_shapes) { + const Array& old_in_types) { const T* params = attrs.as(); // We always make other operators to fit the layouts of convolution layers diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index f082b83a4bd2..72fc328c7490 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -37,20 +37,6 @@ namespace relay { // relay.nn.conv2d TVM_REGISTER_NODE_TYPE(Conv2DAttrs); -template -Array > Conv2DInferCorrectLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array> &old_in_shapes) { - const T* params = attrs.as(); - - // We always make other operators to fit the layouts of convolution layers - // So this inference ignores all inputs - return Array >{{params->data_layout, params->kernel_layout}, - {params->out_layout == "" ? - params->data_layout : params->out_layout}}; -} // Positional relay function to create conv2d operator // used by frontend FFI. diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index 01437302bc92..f94ff7ab7de9 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -138,6 +138,22 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } +template +Array > Conv2DInferCorrectLayout( + const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array &old_in_types) { + const T* params = attrs.as(); + + // We always make other operators to fit the layouts of convolution layers + // So this inference ignores all inputs + return Array >{{params->data_layout, params->kernel_layout}, + {params->out_layout == "" ? + params->data_layout : params->out_layout}}; +} + + } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_NN_CONVOLUTION_H_ diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 3ba31291d275..b635e02b1fa2 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -271,10 +271,10 @@ Array > PReluInferCorrectLayout( const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, - const Array> &old_in_shapes) { + const Array &old_in_types) { CHECK_EQ(old_in_layouts.size(), 2U); - CHECK_EQ(old_in_shapes.size(), 2U); + CHECK_EQ(old_in_types.size(), 2U); Layout data_layout = old_in_layouts[0]; if (new_in_layouts.defined()) { CHECK_EQ(new_in_layouts.size(), 2U); @@ -619,9 +619,15 @@ TVM_REGISTER_NODE_TYPE(BatchNormAttrs); Array> BatchNormInferCorrectLayout(const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, - const Array>& old_in_shapes) { + const Array& old_in_types) { BatchNormAttrs* param = const_cast(attrs.as()); + Array> old_in_shapes; + for (auto old_in_t : old_in_types) { + CHECK(old_in_t.as()); + old_in_shapes.push_back(old_in_t.as()->shape); + } + size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() : static_cast(param->axis); diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index 44bb287f2ee8..57265581958b 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -39,7 +39,7 @@ Array > PadInferCorrectLayout( const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, - const Array> &old_in_shapes) { + const Array &old_in_types) { // NOTE: Discard "const" qualifier here. PadAttrs *params = const_cast(attrs.as()); diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index b8873181e960..b8810b0b3bdf 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -41,7 +41,7 @@ Array > Pool2DInferCorrectLayout( const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, - const Array> &old_in_shapes) { + const Array &old_in_types) { // NOTE: Discard "const" qualifier here. T *params = const_cast(attrs.as()); diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index 6cdf6fc0c7b5..2ac7a41728bb 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -39,7 +39,7 @@ Array > UpsamplingInferCorrectLayout( const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, - const Array> &old_in_shapes) { + const Array &old_in_types) { // NOTE: Discard "const" qualifier here. T *params = const_cast(attrs.as()); diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 48f35ea49d07..0e8dcdb4d4de 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -122,11 +122,16 @@ Array GetExcludeAxes(size_t indim, Array> ReduceInferCorrectLayout(const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, - const Array>& old_in_shapes) { + const Array& old_in_types) { // NOTE: Discard "const" qualifier here. ReduceAttrs* params = const_cast(attrs.as()); // Get the reduce axes. + Array> old_in_shapes; + for (auto old_in_t : old_in_types) { + CHECK(old_in_t.as()); + old_in_shapes.push_back(old_in_t.as()->shape); + } uint32_t indim = old_in_shapes[0].size(); auto r_axes = GetReduceAxes(indim, params->axis, params->exclude); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 46d7e5a6fe11..d33898525d77 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -278,54 +278,6 @@ Array ConcatenateCompute(const Attrs& attrs, return { topi::concatenate(inputs, param->axis) }; } -Array> ConcatenateLayout( - const Attrs& attrs, - const Array& new_in_layouts, - const Array& old_in_layouts, - const Array> &old_in_shapes) { - ConcatenateAttrs* param = const_cast(attrs.as()); - - size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() : - static_cast(param->axis); - - Layout ret; - bool is_new_layout_selected = false; - if (new_in_layouts.defined()) { // this function is called after some operators are alternated. - // If all the new input layouts are same, the new in layout gets selected. For axis, the new - // axis in the new layout is identified. The param->axis is then modified on the fly to conform - // to the new input layout. - const auto& concate_dim = old_in_layouts[0][axis]; - bool all_input_layouts_same = true; - for (auto new_layout : new_in_layouts) { - if (!new_layout.Equals(new_in_layouts[0])) { - all_input_layouts_same = false; - } - } - if (all_input_layouts_same) { - auto new_index = new_in_layouts[0].IndexOf(concate_dim); - ret = new_in_layouts[0]; - param->axis = new_index; - is_new_layout_selected = true; - } - } - - if (!is_new_layout_selected) { - // this function is called on the original correct relay ir - for (size_t i = 0; i < old_in_layouts.size(); ++i) { - if (old_in_layouts[i].defined()) { - ret = old_in_layouts[i]; - break; - } - } - - if (ret.ndim() <= axis || !ret[axis].IsPrimal()) { - return Array > {{Layout::Undef()}, {Layout::Undef()}}; - } - } - - return Array > {Array(old_in_layouts.size(), ret), {ret}}; -} - Expr MakeConcatenate(Expr data, int axis) { auto attrs = make_node(); @@ -1933,9 +1885,14 @@ Array > StridedSliceInferCorrectLayout( const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, - const Array>& old_in_shapes) { + const Array& old_in_types) { CHECK(old_in_layouts.defined()); CHECK_EQ(old_in_layouts.size(), 1); + Array> old_in_shapes; + for (auto old_in_t : old_in_types) { + CHECK(old_in_t.as()); + old_in_shapes.push_back(old_in_t.as()->shape); + } CHECK(old_in_shapes.defined()); CHECK_EQ(old_in_shapes.size(), 1); diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index a702db898e4e..c53006322796 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -24,6 +24,8 @@ #ifndef TVM_RELAY_OP_TENSOR_TRANSFORM_H_ #define TVM_RELAY_OP_TENSOR_TRANSFORM_H_ +#include +#include #include #include #include @@ -123,6 +125,63 @@ bool ConcatenateRel(const Array& types, return true; } +static inline Array> ConcatenateLayout( + const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array &old_in_types) { + ConcatenateAttrs* param = const_cast(attrs.as()); + + Array> old_in_shapes; + CHECK_EQ(old_in_types.size(), 1); + for (auto old_in_tuple_t : old_in_types) { + CHECK(old_in_tuple_t.as()); + for (auto old_in_t : old_in_tuple_t.as()->fields) { + old_in_shapes.push_back(old_in_t.as()->shape); + } + } + + size_t axis = param->axis < 0 ? param->axis + old_in_shapes[0].size() : + static_cast(param->axis); + + Layout ret; + bool is_new_layout_selected = false; + if (new_in_layouts.defined()) { // this function is called after some operators are alternated. + // If all the new input layouts are same, the new in layout gets selected. For axis, the new + // axis in the new layout is identified. The param->axis is then modified on the fly to conform + // to the new input layout. + const auto& concate_dim = old_in_layouts[0][axis]; + bool all_input_layouts_same = true; + for (auto new_layout : new_in_layouts) { + if (!new_layout.Equals(new_in_layouts[0])) { + all_input_layouts_same = false; + } + } + if (all_input_layouts_same) { + auto new_index = new_in_layouts[0].IndexOf(concate_dim); + ret = new_in_layouts[0]; + param->axis = new_index; + is_new_layout_selected = true; + } + } + + if (!is_new_layout_selected) { + // this function is called on the original correct relay ir + for (size_t i = 0; i < old_in_layouts.size(); ++i) { + if (old_in_layouts[i].defined()) { + ret = old_in_layouts[i]; + break; + } + } + + if (ret.ndim() <= axis || !ret[axis].IsPrimal()) { + return Array > {{Layout::Undef()}, {Layout::Undef()}}; + } + } + + return Array > {Array(old_in_layouts.size(), ret), {ret}}; +} + } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_TENSOR_TRANSFORM_H_ diff --git a/src/relay/pass/infer_layout_util.h b/src/relay/pass/infer_layout_util.h index 94eeba101fc2..529912345984 100644 --- a/src/relay/pass/infer_layout_util.h +++ b/src/relay/pass/infer_layout_util.h @@ -89,7 +89,7 @@ inline Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& o * This can be undefined, which means we call this function before alternating * any operators. * \param old_in_layouts The layouts of input arguments before alter_op_layout. - * \param old_in_shapes The shapes of old input arguments. + * \param old_in_types The types of old input arguments. * \return infered_layout An array of two elements that are inferred input layouts and * inferred output layouts. */ @@ -97,13 +97,13 @@ using FInferCorrectLayout = runtime::TypedPackedFunc< Array>(const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, - const Array> &old_in_shapes)>; + const Array &old_in_types)>; /*! \brief take arbitrary input layout and copy to output */ inline Array > ElemwiseArbitraryLayout(const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, - const Array> &old_in_shapes) { + const Array &old_in_types) { Layout ret; if (new_in_layouts.defined()) { @@ -125,8 +125,13 @@ inline Array > ElemwiseArbitraryLayout(const Attrs& attrs, inline Array > BinaryBroadcastLayout(const Attrs& attrs, const Array& new_in_layouts, const Array& old_in_layouts, - const Array> &old_in_shapes) { + const Array &old_in_types) { Array layouts; + Array> old_in_shapes; + for (auto old_in_t : old_in_types) { + CHECK(old_in_t.as()); + old_in_shapes.push_back(old_in_t.as()->shape); + } if (new_in_layouts.defined()) { layouts.assign(new_in_layouts.begin(), new_in_layouts.end()); @@ -202,7 +207,7 @@ inline Array > BinaryBroadcastLayout(const Attrs& attrs, */ static inline std::tuple, Array, bool> InferCorrectLayouts( const Call& call, const Array& new_in_layouts, const Array& old_in_layouts, - const Array>& old_in_shapes) { + const Array& old_in_types) { static auto finfer_layout = Op::GetAttr("FInferCorrectLayout"); if (!call->op.as()) { return std::make_tuple<>(Array(nullptr), Array(nullptr), false); @@ -212,7 +217,7 @@ static inline std::tuple, Array, bool> InferCorrectLayouts if (finfer_layout.count(op)) { Array> inferred_layouts; inferred_layouts = - finfer_layout[op](call->attrs, new_in_layouts, old_in_layouts, old_in_shapes); + finfer_layout[op](call->attrs, new_in_layouts, old_in_layouts, old_in_types); CHECK_EQ(inferred_layouts.size(), 2) << "FInferCorrectLayout should return an array with size of 2"; for (auto x : inferred_layouts) { diff --git a/src/relay/pass/transform_layout.h b/src/relay/pass/transform_layout.h index f6c5e9af6d62..ab0d5a6c883a 100644 --- a/src/relay/pass/transform_layout.h +++ b/src/relay/pass/transform_layout.h @@ -222,7 +222,6 @@ template Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const NodeRef& ctx) { std::vector> inputs; std::vector normal_new_args; - Array> input_shapes; // NOTE: discard the "const" qualifier // TransformMemorizer memorizer = Downcast(ctx); @@ -269,22 +268,16 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Nod old_in.push_back(inp->old_layout); new_in.push_back(inp->new_layout); } + tvm::Array types; for (auto arg : ref_call->args) { - if (arg->IsInstance()) { // flatten tuple - Tuple tuple_arg = Downcast(arg); - for (auto x : tuple_arg->fields) { - input_shapes.push_back(x->type_as()->shape); - } - } else { - input_shapes.push_back(arg->type_as()->shape); - } + types.push_back(arg->checked_type()); } // old_in, old_out = op.infer(old_in) bool success = false; std::tie(old_in, old_out, success) = - InferCorrectLayouts(ref_call, Array(nullptr), old_in, input_shapes); + InferCorrectLayouts(ref_call, Array(nullptr), old_in, types); if (!success) { return Expr(nullptr); } @@ -304,7 +297,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Nod if (new_call->op->IsInstance()) { success = false; std::tie(new_in2, new_out, success) = - InferCorrectLayouts(new_call, new_in, old_in, input_shapes); + InferCorrectLayouts(new_call, new_in, old_in, types); if (!success) { return Expr(nullptr); } diff --git a/src/relay/qnn/op/add.cc b/src/relay/qnn/op/add.cc index 96fab5fa800b..62c9f126a433 100644 --- a/src/relay/qnn/op/add.cc +++ b/src/relay/qnn/op/add.cc @@ -25,6 +25,7 @@ #include #include #include "../../pass/pattern_util.h" +#include "../../pass/infer_layout_util.h" #include "../util.h" #include "op_common.h" @@ -32,6 +33,23 @@ namespace tvm { namespace relay { namespace qnn { +/*! \brief Infer layout for QNN binary broadcast operators */ +Array > QnnBinaryBroadcastLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { + // Use Relay Binary Broadcast Infer correct layout. + auto layouts = BinaryBroadcastLayout(attrs, new_in_layouts, old_in_layouts, old_in_types); + + // Fill the layouts of remaining input tensors - scales and zero points. The layouts of these + // tensors can be treated as C. + Layout channel_layout = Layout("C"); + Array input_layouts = {layouts[0][0], layouts[0][1], channel_layout, channel_layout, + channel_layout, channel_layout, channel_layout, channel_layout}; + Array output_layouts = layouts[1]; + return {input_layouts, output_layouts}; +} + /* * \brief Canonicalizes the QNN add op. * \param attrs The QNN concatenate attrs. @@ -117,7 +135,8 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array& new_args, QNN_REGISTER_BINARY_OP("add") .describe("Elementwise add with with broadcasting for quantized tensors.") .set_support_level(11) -.set_attr("FTVMQnnCanonicalize", QnnAddCanonicalize); +.set_attr("FTVMQnnCanonicalize", QnnAddCanonicalize) +.set_attr("FInferCorrectLayout", QnnBinaryBroadcastLayout); } // namespace qnn } // namespace relay diff --git a/src/relay/qnn/op/concatenate.cc b/src/relay/qnn/op/concatenate.cc index 321c0dcce578..ff95f1fee30c 100644 --- a/src/relay/qnn/op/concatenate.cc +++ b/src/relay/qnn/op/concatenate.cc @@ -28,6 +28,7 @@ #include #include "../../op/tensor/transform.h" #include "../../pass/pattern_util.h" +#include "../../pass/infer_layout_util.h" #include "../util.h" namespace tvm { @@ -78,6 +79,43 @@ Expr MakeQnnConcatenate(Expr data, Expr input_scales, Expr input_zero_points, Ex Attrs(attrs), {}); } +Array> QnnConcatenateLayout(const Attrs& attrs, const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { + // Collect the layouts and types to reuse Relay Concatenate Infer Correct Layout. + CHECK_EQ(old_in_types.size(), 5); + auto input_tuple_type = old_in_types[0].as(); + CHECK(input_tuple_type); + auto num_input_tensors = input_tuple_type->fields.size(); + + Array relay_new_in_layouts(nullptr); + if (new_in_layouts.defined()) { + relay_new_in_layouts = + Array(new_in_layouts.begin(), new_in_layouts.begin() + num_input_tensors); + } + Array relay_old_in_layouts(nullptr); + if (old_in_layouts.defined()) { + relay_old_in_layouts = + Array(old_in_layouts.begin(), old_in_layouts.begin() + num_input_tensors); + } + + // Use Relay Concatenate Infer Correct layout to infer the layouts for data tensors. + auto layouts = + ConcatenateLayout(attrs, relay_new_in_layouts, relay_old_in_layouts, {old_in_types[0]}); + + // Fill the layouts of remaining input tensors - scales and zero points. The layouts of these + // tensors can be treated as channel layout. Total number of these tensors are 2 * num of data + // tensors (scale and zero point for each input data tensor) + 2 for the output data tensor. + Layout channel_layout = Layout("C"); + Array input_layouts = layouts[0]; + + for (size_t i = 0; i < 2 * num_input_tensors + 2; i++) { + input_layouts.push_back(channel_layout); + } + Array output_layouts = layouts[1]; + return {input_layouts, output_layouts}; +} + /* * \brief Canonicalizes the QNN concatenate op. * \param attrs The QNN concatenate attrs. @@ -159,7 +197,8 @@ RELAY_REGISTER_OP("qnn.concatenate") .add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.") .set_support_level(11) .add_type_rel("QnnConcatenate", QnnConcatenateRel) -.set_attr("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize); +.set_attr("FTVMQnnCanonicalize", ConcatenateQnnCanonicalize) +.set_attr("FInferCorrectLayout", QnnConcatenateLayout); TVM_REGISTER_API("relay.qnn.op._make.concatenate") .set_body_typed(MakeQnnConcatenate); diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index b8b2f9237ec3..5eb40a04dea6 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -29,6 +29,7 @@ #include #include "../../op/nn/convolution.h" #include "../../pass/pattern_util.h" +#include "../../pass/infer_layout_util.h" #include "../util.h" namespace tvm { @@ -68,6 +69,23 @@ bool QnnConv2DRel(const Array& types, int num_inputs, const Attrs& attrs, return Conv2DRel(tensor_types, 3, attrs, reporter); } +Array> QnnConvInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { + // Use Relay Conv2D Infer correct layout. + auto layouts = + Conv2DInferCorrectLayout(attrs, new_in_layouts, old_in_layouts, old_in_types); + + // Fill the layouts of remaining input tensors - scales and zero points. The layouts of these + // tensors can be treated as channel layout. + Layout channel_layout = Layout("C"); + Array input_layouts = {layouts[0][0], layouts[0][1], channel_layout, + channel_layout, channel_layout, channel_layout}; + Array output_layouts = layouts[1]; + return {input_layouts, output_layouts}; +} + bool is_depthwise(const Conv2DAttrs* param) { return param->channels.defined() && tvm::ir::Equal(param->channels, param->groups) && param->groups != 1; @@ -681,7 +699,8 @@ operator to understand how to scale back the int32 output to (u)int8. .add_argument("weight_zero_point", "Tensor", "The quantization zero_point of the weight tensor.") .set_support_level(11) .add_type_rel("QnnConv2D", QnnConv2DRel) -.set_attr("FTVMQnnCanonicalize", QnnConv2DCanonicalize); +.set_attr("FTVMQnnCanonicalize", QnnConv2DCanonicalize) +.set_attr("FInferCorrectLayout", QnnConvInferCorrectLayout); TVM_REGISTER_API("relay.qnn.op._make.conv2d").set_body_typed(MakeQnnConv2D); diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index f4118c6e47d6..e36919b07f7f 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -26,6 +26,7 @@ #include #include #include "../../pass/pattern_util.h" +#include "../../pass/infer_layout_util.h" #include "../util.h" namespace tvm { @@ -34,6 +35,79 @@ namespace qnn { TVM_REGISTER_NODE_TYPE(RequantizeAttrs); +Array> RequantizeInferCorrectLayout(const Attrs& attrs, + const Array& new_in_layouts, + const Array& old_in_layouts, + const Array& old_in_types) { + RequantizeAttrs* param = const_cast(attrs.as()); + + Array> old_in_shapes; + for (auto old_in_t : old_in_types) { + CHECK(old_in_t.as()); + old_in_shapes.push_back(old_in_t.as()->shape); + } + + Array input_layouts, output_layouts; + if (new_in_layouts.defined()) { + // Adapt to new layout. The axis has to change. + // Record original reduce axis. Convert to the modified layout axis. + CHECK_EQ(new_in_layouts.size(), 5); + CHECK_EQ(old_in_layouts.size(), 5); + + // 1) Get the axis. + int axis = param->axis; + axis = (axis == -1) ? old_in_shapes[0].size() - 1 : axis; + + // 2) Collect the original axis + std::string old_dim = old_in_layouts[0][axis].name(); + + // 3) Collect the new axes by walking new_layout. + tvm::Integer new_axis; + std::string new_layout_string = ""; + int axis_index = 0; + for (auto iter_var : new_in_layouts[0]->axes) { + const auto& layout_axis = LayoutAxis::Get(iter_var); + const std::string& layout_dim = layout_axis.name(); + if (old_dim == layout_dim) { + new_axis = tvm::Integer(axis_index); + } + // Collect only the primal axis. + if (layout_axis.IsPrimal()) { + new_layout_string += layout_dim; + axis_index++; + } + } + + // 4) Set the new axis and layout. + Layout new_layout = Layout(new_layout_string); + + // Fill the layouts of remaining input tensors - scales and zero points. The layouts of these + // tensors can be treated as channel layout. + Layout channel_layout = Layout("C"); + input_layouts = {new_layout, channel_layout, channel_layout, channel_layout, channel_layout}; + output_layouts = {new_layout}; + param->axis = new_axis; + } else if (old_in_layouts.defined()) { + // If the new layout is undefined, set the old layout as the inferred layout. + CHECK_EQ(old_in_layouts.size(), 5); + + Layout old_layout = old_in_layouts[0]; + + // Fill the layouts of remaining input tensors - scales and zero points. The layouts of these + // tensors can be treated as channel layout. + Layout channel_layout = Layout("C"); + input_layouts = {old_layout, channel_layout, channel_layout, channel_layout, channel_layout}; + output_layouts = {old_layout}; + } else { + // Set the layouts to undef. + Layout undef = Layout::Undef(); + input_layouts = Array(5, undef); + output_layouts = {undef}; + } + + return Array>{input_layouts, output_layouts}; +} + // Lowering of qnn.requantize op /* @@ -244,7 +318,8 @@ Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input) .add_argument("output_zero_point", "Tensor", "The quantization zero_point of the output tensor.") .set_support_level(11) .add_type_rel("Requantize", RequantizeRel) -.set_attr("FTVMQnnCanonicalize", RequantizeQnnCanonicalize); +.set_attr("FTVMQnnCanonicalize", RequantizeQnnCanonicalize) +.set_attr("FInferCorrectLayout", RequantizeInferCorrectLayout); TVM_REGISTER_API("relay.qnn.op._make.requantize") .set_body_typed(MakeRequantize);