From fb6b4b7eb3095831b0a0537d3a3fed25b249e6d9 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 12 Jun 2019 00:35:09 +0000 Subject: [PATCH] [Relay] [Quantization] Protoyping the quantized convolution and requantized op Features - New quantized conv2D and requantize op in Relay - Python API interface to instantiate the Relay op - Infer Type implemented - Lowering of quantized op to low-level Relay ops --- include/tvm/relay/attrs/qnn.h | 131 +++++ include/tvm/relay/quantize_util.h | 139 ++++++ python/tvm/relay/op/__init__.py | 1 + python/tvm/relay/op/qnn/__init__.py | 21 + python/tvm/relay/op/qnn/_make.py | 20 + python/tvm/relay/op/qnn/qnn.py | 172 +++++++ python/tvm/relay/quantize/__init__.py | 1 + python/tvm/relay/quantize/rewrite.py | 38 ++ src/relay/op/nn/convolution.cc | 70 ++- src/relay/op/nn/requantize.cc | 91 ++++ src/relay/pass/pattern_util.h | 45 ++ src/relay/pass/quantize_rewrite.cc | 330 +++++++++++++ tests/python/unittest/test_quantized_ops.py | 500 ++++++++++++++++++++ 13 files changed, 1555 insertions(+), 4 deletions(-) create mode 100644 include/tvm/relay/attrs/qnn.h create mode 100644 include/tvm/relay/quantize_util.h create mode 100644 python/tvm/relay/op/qnn/__init__.py create mode 100644 python/tvm/relay/op/qnn/_make.py create mode 100644 python/tvm/relay/op/qnn/qnn.py create mode 100644 python/tvm/relay/quantize/rewrite.py create mode 100644 src/relay/op/nn/requantize.cc create mode 100644 src/relay/pass/quantize_rewrite.cc create mode 100644 tests/python/unittest/test_quantized_ops.py diff --git a/include/tvm/relay/attrs/qnn.h b/include/tvm/relay/attrs/qnn.h new file mode 100644 index 000000000000..8d62c66bac2f --- /dev/null +++ b/include/tvm/relay/attrs/qnn.h @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/attrs/nn.h + * \brief Auxiliary attributes for nn operators. + */ +#ifndef TVM_RELAY_ATTRS_NN_QUANTIZE_H_ +#define TVM_RELAY_ATTRS_NN_QUANTIZE_H_ + +#include +#include + +namespace tvm { +namespace relay { + +/*! \brief Attribute for quantized conv2d operator */ +struct QConv2DAttrs : public tvm::AttrsNode { + // Traditional conv2d attributes. + Array strides; + Array padding; + Array dilation; + int groups; + IndexExpr channels; + Array kernel_size; + std::string data_layout; + std::string kernel_layout; + std::string out_layout; + DataType out_dtype; + + // Quantization related attributes. 
+ int32_t input_zero_point; + int32_t kernel_zero_point; + + TVM_DECLARE_ATTRS(QConv2DAttrs, "relay.attrs.QConv2DAttrs") { + TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) + .describe("Specifies the strides of the convolution."); + TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) + .describe("If padding is non-zero, then the input is implicitly zero-padded" + "on both sides for padding number of points"); + TVM_ATTR_FIELD(dilation).set_default(Array({1, 1})) + .describe("Specifies the dilation rate to use for dilated convolution."); + TVM_ATTR_FIELD(groups).set_default(1) + .describe("Controls the connections between inputs and outputs." + "At groups=1, all inputs are convolved to all outputs." + "At groups=2, the operation becomes equivalent to having two convolution" + "layers side by side, each seeing half the input channels, and producing" + "half the output channels, and both subsequently concatenated."); + TVM_ATTR_FIELD(channels) + .describe("The number of output channels in the convolution." + " If it is not set, inferred by shape of the weight.") + .set_default(NullValue()); + TVM_ATTR_FIELD(kernel_size) + .describe("Specifies the dimensions of the convolution window.") + .set_default(NullValue >()); + TVM_ATTR_FIELD(data_layout).set_default("NCHW") + .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(kernel_layout).set_default("OIHW") + .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." + "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" + "dimensions respectively."); + TVM_ATTR_FIELD(out_layout).set_default("") + .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Default to be same as input layout."); + TVM_ATTR_FIELD(out_dtype) + .set_default(NullValue()) + .describe("Output data type, set to explicit type under mixed precision setting"); + TVM_ATTR_FIELD(input_zero_point) + .describe("The zero point of the input tensor."); + TVM_ATTR_FIELD(kernel_zero_point) + .describe("The zero point of the kernel tensor."); + } +}; + + +/*! \brief Attribute for requantize operator */ +struct RequantizeAttrs : public tvm::AttrsNode { + double input_scale; + int32_t input_zero_point; + double output_scale; + int32_t output_zero_point; + bool use_int_compute; + std::string rounding_mode; + DataType out_dtype; + + TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") { + TVM_ATTR_FIELD(input_zero_point) + .describe("The zero point of the input tensor."); + TVM_ATTR_FIELD(output_zero_point) + .describe("The zero point of the output tensor."); + TVM_ATTR_FIELD(input_scale) + .describe("The scale of the input tensor."); + TVM_ATTR_FIELD(output_scale) + .describe("The scale of the output tensor."); + TVM_ATTR_FIELD(use_int_compute).set_default(false) + .describe("When true, the integer computation is used to handle output scale"); + TVM_ATTR_FIELD(out_dtype) + .set_default(NullValue()) + .describe("Output data type, set to explicit type under mixed precision setting"); + TVM_ATTR_FIELD(rounding_mode).set_default("FE_UPWARD") + .describe("Defines the rounding direction when the value is midway between" + "two representable values. There are two supported modes - FE_UPWARD" + "or FE_AWAY_FROM_ZERO. 
More context can be found at" + "https://www.gnu.org/software/libc/manual/html_node/Rounding.html"); + } +}; + + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_ATTRS_NN_QUANTIZE_H_ diff --git a/include/tvm/relay/quantize_util.h b/include/tvm/relay/quantize_util.h new file mode 100644 index 000000000000..bb054fb8fb65 --- /dev/null +++ b/include/tvm/relay/quantize_util.h @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file nnvm/compiler/quantize_util.h + * \brief Utility methods needs for quantized ops that can be shared + */ + +#ifndef TVM_QUANTIZE_UTIL_H +#define TVM_QUANTIZE_UTIL_H + +#include +#include "./base.h" + +namespace tvm { +namespace relay { + +inline bool is_Int8(const DataType& dtype) { + return dtype == Int(8); +} + +inline bool is_UInt8(const DataType& dtype) { + return dtype == UInt(8); +} + + +inline bool is_Int16(const DataType& dtype) { + return dtype == Int(16); +} + +inline bool is_UInt16(const DataType& dtype) { + return dtype == UInt(16); +} + +inline bool is_Int32(const DataType& dtype) { + return dtype == Int(32); +} + +inline bool is_UInt32(const DataType& dtype) { + return dtype == UInt(32); +} + + + +inline bool is_Float32(const DataType& dtype) { + return dtype == Float(32); +} + +inline bool is_quantized_type(const DataType& dtype) { + return is_Int8(dtype) || is_UInt8(dtype) + || is_Int16(dtype) || is_UInt16(dtype); +} + +enum class QuantizeOpType : uint8_t { + Quantize_Requantize, + Dequantize, + Requantize +}; + +inline bool is_valid_quantized_op_input_type(const QuantizeOpType &op_type, const DataType &in_dtype) { + switch(op_type) { + case QuantizeOpType::Quantize_Requantize: + return is_Float32(in_dtype) || is_quantized_type(in_dtype); + case QuantizeOpType ::Dequantize: + return is_quantized_type(in_dtype); + case QuantizeOpType ::Requantize: + return is_Int16(in_dtype) || is_Int32(in_dtype); + default: + return false; + } +} + +inline bool is_valid_quantized_op_output_type(const QuantizeOpType &op_type, const DataType &in_dtype) { + switch(op_type) { + case QuantizeOpType::Quantize_Requantize: + return is_quantized_type(in_dtype); + case QuantizeOpType::Dequantize: + return is_Float32(in_dtype); + default: + return false; + } +} + +inline const int32_t get_qmin(const DataType& dtype) { + if (is_Int8(dtype)) { + return std::numeric_limits::min(); + } else if (is_UInt8(dtype)) { + return std::numeric_limits::min(); + } else if (is_Int16(dtype)) { + return std::numeric_limits::min(); + } else if (is_UInt16(dtype)) { + return std::numeric_limits::min(); + } else if (is_Int32(dtype)) { + return std::numeric_limits::min(); + } else if (is_UInt32(dtype)) { + return std::numeric_limits::min(); + } + LOG(FATAL) << "Type not supported\n"; + return -1; +} + + 
+inline const int32_t get_qmax(const DataType& dtype) { + if (is_Int8(dtype)) { + return std::numeric_limits::max(); + } else if (is_UInt8(dtype)) { + return std::numeric_limits::max(); + } else if (is_Int16(dtype)) { + return std::numeric_limits::max(); + } else if (is_UInt16(dtype)) { + return std::numeric_limits::max(); + } else if (is_Int32(dtype)) { + return std::numeric_limits::max(); + } else if (is_UInt32(dtype)) { + return std::numeric_limits::max(); + } + LOG(FATAL) << "Type not supported\n"; + return -1; +} + +} // namespace relay +} // namespace tvm +#endif //TVM_QUANTIZE_UTIL_H diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index a27ab1dc50ff..1d634ef18fc0 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -26,6 +26,7 @@ from .transform import * from .algorithm import * from . import nn +from . import qnn from . import annotation from . import image from . import vision diff --git a/python/tvm/relay/op/qnn/__init__.py b/python/tvm/relay/op/qnn/__init__.py new file mode 100644 index 000000000000..e42063eed26c --- /dev/null +++ b/python/tvm/relay/op/qnn/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import +"""Neural network related operators.""" +from __future__ import absolute_import as _abs +from .qnn import * +# from . import _nn diff --git a/python/tvm/relay/op/qnn/_make.py b/python/tvm/relay/op/qnn/_make.py new file mode 100644 index 000000000000..b1695629b8f9 --- /dev/null +++ b/python/tvm/relay/op/qnn/_make.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Constructor APIs""" +from ...._ffi.function import _init_api + +_init_api("relay.op.qnn._make", __name__) diff --git a/python/tvm/relay/op/qnn/qnn.py b/python/tvm/relay/op/qnn/qnn.py new file mode 100644 index 000000000000..95abbb807d18 --- /dev/null +++ b/python/tvm/relay/op/qnn/qnn.py @@ -0,0 +1,172 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +#pylint: disable=invalid-name, too-many-lines +"""Neural network operations.""" +from __future__ import absolute_import as _abs +from . import _make + + +def conv2d(quantized_data, + quantized_weight, + input_zero_point, + kernel_zero_point, + strides=(1, 1), + padding=(0, 0), + dilation=(1, 1), + groups=1, + channels=None, + kernel_size=None, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="", + out_dtype="int32"): + r"""Quantized 2D convolution. + + This operator takes the quantized_weight as the convolution kernel + and convolves it with quantized_data to produce an output quantized tensor. + The scale of the output quantized tensor is the prodcut of the weight_scale + and input_scale of the input quantized tensors. The zero point of the output + quantized tensor is 0. By default, the dtype of output is int32. Please also + see Requantize operator to understand the dtype scaling back to (u)int8. + + In the default case, where the data_layout is `NCHW` + and kernel_layout is `OIHW`, conv2d takes in + a quantized_data Tensor with shape `(batch_size, in_channels, height, width)`, + and a quantized_weight Tensor with shape `(channels, in_channels, kernel_size[0], kernel_size[1])` + to produce an output Tensor with the following rule: + + .. math:: + + \mbox{out}[b, c, y, x] = \sum_{dy, dx, k} + \mbox{quantized_data}[b, k, \mbox{strides}[0] * y + dy, \mbox{strides}[1] * x + dx] * + \mbox{quantized_weight}[c, k, dy, dx] + + Padding and dilation are applied to quantized_data and quantized_weight respectively before the computation. + This operator accepts quantized_data layout specification. + Semantically, the operator will convert the layout to the canonical layout + (`NCHW` for quantized_data and `OIHW` for quantized_weight), perform the computation, + then convert to the out_layout. + + + Parameters + ---------- + quantized_data : tvm.relay.Expr + The input quantized_data to the operator. + + quantized_weight : tvm.relay.Expr + The quantized_weight expressions. + + input_zero_point: int + The zero point of the quantized_data distribution. + + kernel_zero_point: int + The zero point of the quantized_kernel distribution. + + strides : tuple of int, optional + The strides of convolution. + + padding : tuple of int, optional + The padding of convolution on both sides of inputs before convolution. + + dilation : tuple of int, optional + Specifies the dilation rate to be used for dilated convolution. + + groups : int, optional + Number of groups for grouped convolution. + + channels : int, optional + Number of output channels of this convolution. + + kernel_size : tuple of int, optional + The spatial of the convolution kernel. + + data_layout : str, optional + Layout of the input. + + kernel_layout : str, optional + Layout of the quantized_weight. 
+ + out_layout : str, optional + Layout of the output, by default, out_layout is the same as data_layout + + out_dtype : str, optional + Specifies the output quantized_data type for mixed precision conv2d. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + return _make.conv2d(quantized_data, quantized_weight, + input_zero_point, kernel_zero_point, + strides, padding, dilation, + groups, channels, kernel_size, + data_layout, kernel_layout, out_layout, out_dtype) + +def requantize(input_data, input_zero_point, input_scale, output_zero_point, + output_scale, out_dtype="int32", use_int_compute=False, + rounding_mode="FE_UPWARD"): + r"""Requantized operator. + + The requantize operator converts one quantized tensor to another quantized + tensor. For the output tensor, we are provided with output scale and zero + point. The computation looks like this + + Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input) + + The above computation can be done in floating point as the scales are in + FP32. Alternatively, we can approximate floating point with fixed point + computation. This is controlled by use_int_compute. + + Parameters + ---------- + quantized_data : tvm.relay.Expr + The input quantized_data to the operator. + + input_scale: float + The float scalar to scale the quantized_data int8 values back to FP32. + + output_scale: float + The float scalar to scale the quantized_output int8 values back to FP32. + + input_zero_point: int + The zero point of the quantized_data distribution. + + output_zero_point: int + The zero point of the quantized_output distribution. + + out_dtype : str, optional + Specifies the output quantized_data type for mixed precision conv2d. + + use_int_compute : bool, optional + Use fully integer computation for requantizing. + + rounding_mode : string, optional + Defines the rounding direction when the value is midway between two + representable values. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + assert rounding_mode in ("FE_UPWARD", "FE_AWAY_FROM_ZERO"),\ + "Unsupported rounding mode" + + return _make.requantize(input_data, input_zero_point, input_scale, + output_zero_point, output_scale, out_dtype, + use_int_compute, rounding_mode) diff --git a/python/tvm/relay/quantize/__init__.py b/python/tvm/relay/quantize/__init__.py index 45bb62e66853..8da4e7953566 100644 --- a/python/tvm/relay/quantize/__init__.py +++ b/python/tvm/relay/quantize/__init__.py @@ -19,4 +19,5 @@ from __future__ import absolute_import as _abs from .quantize import * +from .rewrite import * from ._annotate import register_annotate_function diff --git a/python/tvm/relay/quantize/rewrite.py b/python/tvm/relay/quantize/rewrite.py new file mode 100644 index 000000000000..89429e522115 --- /dev/null +++ b/python/tvm/relay/quantize/rewrite.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +#pylint: disable=unused-argument +"""Automatic quantization toolkit.""" +from __future__ import absolute_import + +from . import _quantize +from .. import expr as _expr + +def rewrite(expr): + """ + Rewrites the high-level quantized ops into low-level exisiting Relay ops. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + + Returns + ------- + expr : tvm.relay.Expr + The output expression. + """ + return _quantize.rewrite(expr) diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 97cba7964000..ed50080346bb 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include "../../pass/alter_op_layout.h" @@ -35,6 +36,7 @@ namespace relay { // relay.nn.conv2d TVM_REGISTER_NODE_TYPE(Conv2DAttrs); +template bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -46,7 +48,7 @@ bool Conv2DRel(const Array& types, static const Layout kNCHW("NCHW"); static const Layout kOIHW("OIHW"); - const Conv2DAttrs* param = attrs.as(); + const auto param = attrs.as(); CHECK(param != nullptr); const Layout in_layout(param->data_layout); const Layout kernel_layout(param->kernel_layout); @@ -191,7 +193,7 @@ with the layer input to produce a tensor of outputs. .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(2) -.add_type_rel("Conv2D", Conv2DRel) +.add_type_rel("Conv2D", Conv2DRel) .set_attr("FInferCorrectLayout", Conv2DInferCorrectLayout); @@ -701,7 +703,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc") .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(10) -.add_type_rel("Conv2D", Conv2DRel) +.add_type_rel("Conv2D", Conv2DRel) .set_attr("FInferCorrectLayout", Conv2DInferCorrectLayout); @@ -751,7 +753,7 @@ RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc") .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(10) -.add_type_rel("Conv2D", Conv2DRel) +.add_type_rel("Conv2D", Conv2DRel) .set_attr("FInferCorrectLayout", Conv2DInferCorrectLayout); @@ -896,6 +898,66 @@ Expr MakeDeformableConv2D(Expr data, TVM_REGISTER_API("relay.op.nn._make.deformable_conv2d") .set_body_typed(MakeDeformableConv2D); +// relay.op.qnn.conv2d +TVM_REGISTER_NODE_TYPE(QConv2DAttrs); + +// Positional relay function to create quantized conv2d operator +// used by frontend FFI. 
+Expr MakeQConv2D(Expr quantized_data, + Expr quantized_weight, + int32_t input_zero_point, + int32_t kernel_zero_point, + Array strides, + Array padding, + Array dilation, + int groups, + IndexExpr channels, + Array kernel_size, + std::string data_layout, + std::string kernel_layout, + std::string out_layout, + DataType out_dtype) { + auto attrs = make_node(); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); + attrs->groups = groups; + attrs->channels = std::move(channels); + attrs->kernel_size = std::move(kernel_size); + attrs->data_layout = std::move(data_layout); + attrs->kernel_layout = std::move(kernel_layout); + attrs->out_layout = std::move(out_layout); + attrs->out_dtype = std::move(out_dtype); + attrs->input_zero_point = std::move(input_zero_point); + attrs->kernel_zero_point = std::move(kernel_zero_point); + static const Op& op = Op::Get("qnn.conv2d"); + return CallNode::make(op, {quantized_data, quantized_weight}, Attrs(attrs), {}); +} + +RELAY_REGISTER_OP("qnn.conv2d") +.describe(R"code(2D quantized convolution layer. + +This operator creates a quantized convolution kernel that is convolved +with the quantized input to produce a tensor of quantized outputs. The +operator is further lowered to existing set of Relay operators. + +- **quantized_data**: This depends on the `layout` parameter. Input is 4D array of shape + (batch_size, in_channels, height, width) if `layout` is `NCHW`. +- **quantized_weight**: (channels, in_channels, kernel_size[0], kernel_size[1]) +- **quantized_out**: This depends on the `layout` parameter. Output is 4D array of shape + (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. + +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.QConv2DAttrs") +.set_num_inputs(2) +.add_argument("quantized_data", "Tensor", "The quantized input quantized_data tensor.") +.add_argument("quantized_weight", "Tensor", "The quantized quantized_weight tensor.") +.set_support_level(10) +.add_type_rel("QConv2D", Conv2DRel); + +TVM_REGISTER_API("relay.op.qnn._make.conv2d") +.set_body_typed(MakeQConv2D); + } // namespace relay } // namespace tvm diff --git a/src/relay/op/nn/requantize.cc b/src/relay/op/nn/requantize.cc new file mode 100644 index 000000000000..285528993f6f --- /dev/null +++ b/src/relay/op/nn/requantize.cc @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2018 by Contributors + * \file requantize.cc + * \brief Quantized convolution operators + */ + +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(RequantizeAttrs); + + +bool RequantizeRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + const auto input_dtype = data->dtype; + CHECK(is_valid_quantized_op_input_type(QuantizeOpType::Requantize, input_dtype)) + << "Input type should be a quantized type (u)int8 or (u)int16 but was " << input_dtype; + + const Array oshape = data->shape; + // assign output type + const RequantizeAttrs* param = attrs.as(); + reporter->Assign(types[1], TensorTypeNode::make(oshape, param->out_dtype)); + return true; +} + +// Positional relay function to create quantized conv2d operator +// used by frontend FFI. +Expr MakeRequantize(Expr data, + int32_t input_zero_point, + double input_scale, + int32_t output_zero_point, + double output_scale, + DataType out_dtype, + bool use_int_compute, + std::string rounding_mode) { + auto attrs = make_node(); + attrs->out_dtype = std::move(out_dtype); + attrs->input_zero_point = std::move(input_zero_point); + attrs->output_zero_point = std::move(output_zero_point); + attrs->input_scale = std::move(input_scale); + attrs->output_scale = std::move(output_scale); + attrs->use_int_compute = std::move(use_int_compute); + attrs->rounding_mode = std::move(rounding_mode); + static const Op& op = Op::Get("qnn.requantize"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +RELAY_REGISTER_OP("qnn.requantize") +.describe(R"code(Requantize operator. + +FIXME +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.RequantizeAttrs") +.set_num_inputs(1) +.add_argument("data", "Tensor", "The quantized input tensor.") +.set_support_level(10) +.add_type_rel("Requantize", RequantizeRel); + +TVM_REGISTER_API("relay.op.qnn._make.requantize") +.set_body_typed(MakeRequantize); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 5c303905968e..a1c3fadcbb11 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -372,7 +372,52 @@ inline Expr Copy(Expr data) { return CallNode::make(op, {data}, Attrs(), {}); } +inline Expr Conv2D(Expr data, + Expr weight, + Array strides, + Array padding, + Array dilation, + int groups, + IndexExpr channels, + Array kernel_size, + std::string data_layout, + std::string kernel_layout, + std::string out_layout, + DataType out_dtype) { + auto attrs = make_node(); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); + attrs->groups = groups; + attrs->channels = std::move(channels); + attrs->kernel_size = std::move(kernel_size); + attrs->data_layout = std::move(data_layout); + attrs->kernel_layout = std::move(kernel_layout); + attrs->out_layout = std::move(out_layout); + attrs->out_dtype = std::move(out_dtype); + static const Op& op = Op::Get("nn.conv2d"); + return CallNode::make(op, {data, weight}, Attrs(attrs), {}); +} + +inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) { + static const Op& op = Op::Get("where"); + return CallNode::make(op, {condition, x, y}); +} +inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) { + static const Op& op = Op::Get("greater_equal"); + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); +} + +inline Expr 
Full(Expr fill_value, + Array shape, + DataType dtype) { + auto attrs = make_node(); + attrs->shape = std::move(shape); + attrs->dtype = std::move(dtype); + static const Op& op = Op::Get("full"); + return CallNode::make(op, {fill_value}, Attrs(attrs), {}); +} Expr MakeConcatenate(Expr data, int axis); Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides); diff --git a/src/relay/pass/quantize_rewrite.cc b/src/relay/pass/quantize_rewrite.cc new file mode 100644 index 000000000000..0e94aa1914c4 --- /dev/null +++ b/src/relay/pass/quantize_rewrite.cc @@ -0,0 +1,330 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file quantize_rewrite.cc + * \brief Lower quantized ops to exisiting Relay ops. + */ + +#include +#include +#include +#include +#include +#include "pattern_util.h" + +namespace tvm { +namespace relay { + + +// Lowering of qnn.requantize op + +/* + * Converts a floating point number so that it can be represented by integers. + * The representation is + * float_number = (fixed_point_multiplier) * 2^(shift) + * + * The fixed_point_multiplier is a number between 0.5 and 1. This is represented + * by an integer number. For example, if it is int32, then the decimal point + * exists between bit 31 and 30 from LSB (or between first and second bit from + * the left). + * + * Some examples are + * 0.25 = (0.5) * 2^(-1) + * 0.125 = (0.5) * 2^(-2) + */ +void GetFixedPointMultiplierShift(double double_multiplier, + int32_t* fixed_point_multiplier, int* shift, + const DataType& idtype) { + + int idtype_bits = idtype.bits(); + + if (double_multiplier == 0.) { + *fixed_point_multiplier = 0; + *shift = 0; + return; + } + const double q = std::frexp(double_multiplier, shift); + auto q_fixed = static_cast(std::round(q * (1ll << (idtype_bits - 1)))); + CHECK_LE(q_fixed, (1ll << (idtype_bits - 1))); + if (q_fixed == (1ll << (idtype_bits - 1))) { + q_fixed /= 2; + ++*shift; + } + CHECK_LE(q_fixed, std::numeric_limits::max()); + *fixed_point_multiplier = static_cast(q_fixed); +} + +/* + * Requantization using only integer computation. Here, the computation is + * converted to a fixed point computation by computing output multiplier and + * shift. This is useful, if the target device does not support/have very + * expensive floating point computations. + * + * Original compuation is scale_fp32 * quantized_tensor. To convert into + * integer computation, the multiplication with fp32 scalar can be replaced by + * multiplication with an int value and then right shifting the result. This + * approximates the floating point computation with a fixed point computation. 
+ * + * The whole computation this can be broken down into following steps + * 1) Calculate the integer multiplier and integer shift. + * 2) Subtract the input integer point. + * 2) Multiply the integer fixed point multiplier with quantized tensor. + * 3) Round the result. + * 4) Right shift the result. + * 5) Add the output_zero_point. + * 6) Cast to the out_dtype. + * + */ +Expr RequantizeInt(const Expr& input_tensor, + const RequantizeAttrs*& param, const DataType& idtype, + const Array& out_shape) { + + double double_multiplier = param->input_scale/param->output_scale; + + // The multiplication will be performed in higher precision. Find the dtype. + int idtype_bits = idtype.bits(); + DataType up_idtype = Int(2 * idtype_bits); + + // 1) Calculating the integer multiplier and integer shift + int32_t fixed_point_multiplier; + int shift; + GetFixedPointMultiplierShift(double_multiplier, &fixed_point_multiplier, + &shift, idtype); + int left_shift = shift > 0 ? shift : 0; + int right_shift = shift > 0 ? 0 : -shift; + + // 2) Subtract the input_zero_point + auto tensor = input_tensor; + tensor = Cast(tensor, up_idtype); + if (param->input_zero_point != 0) { + auto input_zp = MakeConstantScalar(up_idtype, param->input_zero_point); + tensor = Subtract(tensor, input_zp); + } + + + + // 3) Multiply the integer multiplier + if (left_shift != 0) { + tensor = Multiply(tensor, MakeConstantScalar(up_idtype, 1 << left_shift)); + } + // Perform the multiplication in higher precision. + // If idtype is Int(32), the scalar is a fixed point value of int32 where the + // decimal point is between bits 31 and 30. After multiplying with + // input_tensor, the result in int64 where the decimal point is sitting + // between bits 31 and 30 (from the right, rightmost bit is bit 0). + Expr scalar = MakeConstantScalar(up_idtype, fixed_point_multiplier); + auto multiplied_t = Multiply(tensor, scalar); + + + // 4) Find the rounding scalar. This depends on where the final decimal point + // sits. As we will be right shifting the multiplied_t, we need to first + // calculate the totol_right_shift. + int total_right_shift = right_shift + idtype_bits - 1; + + tensor = multiplied_t; + Expr round_scalar; + if (param->rounding_mode == "FE_UPWARD") { + auto pos_rounder = MakeConstantScalar(up_idtype, (1ll << (total_right_shift - 1))); + round_scalar = pos_rounder; + } else if (param->rounding_mode == "FE_AWAY_FROM_ZERO") { + auto pos_rounder = MakeConstantScalar(up_idtype, (1ll << (total_right_shift - 1))); + auto neg_rounder = MakeConstantScalar(up_idtype, (1ll << (total_right_shift - 1)) - 1); + auto pos_rounder_t = Full(pos_rounder, out_shape, up_idtype); + auto neg_rounder_t = Full(neg_rounder, out_shape, up_idtype); + + auto zero = MakeConstantScalar(up_idtype, 0); + auto zero_t = Full(zero, out_shape, up_idtype); + round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, + neg_rounder_t); + } + // Add the rounding scalar. + tensor = Add(tensor, round_scalar); + + // 5) Simply right shift the result to get the final output. + auto scaled_int64_t = RightShift(tensor, + MakeConstantScalar(up_idtype, total_right_shift)); + + // 6) Add the output zero point. + auto output_zp = MakeConstantScalar(up_idtype, param->output_zero_point); + auto shifted_int64_t = Add(output_zp, scaled_int64_t); + + // 7) Clip to the out_dtype min/max. + // Find the right clip min/maxes. While clipping, it is necessary that + // clip_min and clip_max are within the dtype range of the input tensor to the + // clip operator. 
For example, if the input to clip operator is int8, but the + // out_dtype is uint8, we will get incorrect results, if we set max as 255. + auto q_min = std::max(get_qmin(param->out_dtype), get_qmin(idtype)); + auto q_max = std::min(get_qmax(param->out_dtype), get_qmax(idtype)); + auto clipped_t = Clip(shifted_int64_t, q_min, q_max); + auto requantized_output = Cast(clipped_t, param->out_dtype); + return requantized_output; +} + + +/* + * Requantization using floating computation. Here we can multiply the scale to + * the input_tensor, round to nearest integer and then cast back to int32. + */ +Expr RequantizeFloat(const Expr& input_tensor, + const RequantizeAttrs*& param, const DataType& idtype, + const Array& out_shape) { + double double_multiplier = param->input_scale/param->output_scale; + auto scalar_multiplier = MakeConstantScalar(Float(32), double_multiplier); + auto input_zp = MakeConstantScalar(idtype, param->input_zero_point); + auto output_zp = MakeConstantScalar(Float(32), param->output_zero_point); + + // Multiply the convolved tensor with the new scale. + auto shifted_input_t = Subtract(input_tensor, input_zp); + auto casted_t = Cast(shifted_input_t, Float(32)); + auto multiplied_t = Multiply(casted_t, scalar_multiplier); + auto shifted_multiplied_t = Add(output_zp, multiplied_t); + auto rounded_t = Round(shifted_multiplied_t); + auto q_imin = get_qmin(idtype); + auto q_imax = get_qmax(idtype); + auto scaled_int32_t = Cast(Clip(rounded_t, q_imin, q_imax), + idtype); + + // Clip to the out_dtype min/max. + // Clip limits must be smaller than the dtype of the input tensor. + auto q_min = std::max(get_qmin(param->out_dtype), get_qmin(idtype)); + auto q_max = std::min(get_qmax(param->out_dtype), get_qmax(idtype)); + auto clipped_t = Clip(scaled_int32_t, q_min, q_max); + auto requantized_output = Cast(clipped_t, param->out_dtype); + return requantized_output; +} + +/* + * Lowering of the requantize operation. The requantize operator converts one + * quantized tensor to another quantized tensor. For the output tensor, we are + * provided with output scale and zero point. The computation looks like this + * + * Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input) + * + * The above computation can be done in floating point as the scales are in + * FP32. Alternatively, we can approximate floating point with fixed point + * computation. This is controlled by use_int_compute. + */ +Expr RequantizeForwardRewrite(const Call& ref_call, + const Array& new_args, const NodeRef& ctx) { + CHECK_EQ(new_args.size(), 1); + Expr quantized_data = new_args[0]; + const auto* param = ref_call->attrs.as(); + + // Find output shape. + Array out_shape; + auto ref_call_t = ref_call->checked_type(); + auto output_tt = ref_call_t.as(); + CHECK(output_tt != nullptr) << "Type information missing." + << " Please run infer_type pass."; + out_shape = output_tt->shape; + + // Find input dtype. + auto ref_input_t = ref_call->args[0]->checked_type(); + auto input_tt = ref_input_t.as(); + CHECK(input_tt != nullptr) << "Type information missing." 
+ << " Please run infer_type pass."; + const auto input_dtype = input_tt->dtype; + + if (param->use_int_compute) { + return RequantizeInt(quantized_data, param, input_dtype, out_shape); + } else { + return RequantizeFloat(quantized_data, param, input_dtype, out_shape); + } +} + + +RELAY_REGISTER_OP("qnn.requantize") +.set_attr("FQuantizeForwardRewrite", RequantizeForwardRewrite); + +// Lowering of qnn.conv2d op + +/* + * Lowering of the quantized_convolution. + * + * A quantized tensor is represented in following manner + * A = scale_a x (QA - zp_A) + * where QA is quantized tensor, scale_a and zp_A are quantizations params. + * + * Quantized convlution convolves two quantized tensors and returns a quantized + * tensor of default dtype of int32, with scale equaling to the product of + * scales of input tensors, and a zero point of zero. + * + * For symmetric quantization, the zp_* for all tensors is 0. So, the lowering + * of qnn.conv2d is + * + * QA(n, ic, oh + r, ow + s) (conv) QW(oc, ic, r, s) + * + * For asymmetric computation, we can perform similar unrolling. We can find + * more details at + * https://discuss.tvm.ai/t/tf-lite-quantized-conv2d-operator-conversion/2651/8?u=janimesh + */ +Expr QConv2DForwardRewrite(const Call& ref_call, + const Array& new_args, const NodeRef& ctx) { + CHECK_EQ(new_args.size(), 2); + Expr quantized_data = new_args[0]; + Expr quantized_kernel = new_args[1]; + const auto* param = ref_call->attrs.as(); + + Array out_shape; + auto ref_call_t = ref_call->checked_type(); + auto output_tt = ref_call_t.as(); + CHECK(output_tt != nullptr) << "Type information missing." + << " Please run infer_type pass."; + out_shape = output_tt->shape; + + // Check for current quantization support. + CHECK_EQ(param->input_zero_point, 0) + << "Encountered non-zero zero point." + << " Only symmetric quantization supported for now."; + CHECK_EQ(param->kernel_zero_point, 0) + << "Encountered non-zero zero point." + << " Only symmetric quantization supported for now."; + + if (param->input_zero_point == 0 && param->kernel_zero_point == 0) { + Expr int8_conv = Conv2D(quantized_data, + quantized_kernel, + param->strides, + param->padding, + param->dilation, + param->groups, + param->channels, + param->kernel_size, + param->data_layout, + param->kernel_layout, + param->out_layout, + param->out_dtype); + return int8_conv; + } + LOG(FATAL) << "Only symmetric quantization supported"; + return Expr(); // to hide the warning. +} + +RELAY_REGISTER_OP("qnn.conv2d") +.set_attr("FQuantizeForwardRewrite", QConv2DForwardRewrite); + +TVM_REGISTER_API("relay._quantize.rewrite") +.set_body_typed([](const Expr& e) { + Expr ret = ForwardRewrite(e, "FQuantizeForwardRewrite", nullptr, nullptr); + return ret; +}); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/unittest/test_quantized_ops.py b/tests/python/unittest/test_quantized_ops.py new file mode 100644 index 000000000000..e0a43b281039 --- /dev/null +++ b/tests/python/unittest/test_quantized_ops.py @@ -0,0 +1,500 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import numpy as np +from tvm import relay +from tvm.relay.testing import create_workload +from tvm.contrib import graph_runtime + +# TODOs for janimesh before submitting this patch. +# TODO - Add tests for int8 input/weight dtype +# TODO - opt_level=0 fails mostly due to fusion. +# TODO - opt_level=3 fails, likely culprit kernel layout for int8 +# compute. Work with Rankyung to see if this is the culprit. Handle +# it in a separate patch. +rounding_modes = ["FE_UPWARD", "FE_AWAY_FROM_ZERO"] +def run_infer_type(expr): + mod = relay.Module.from_expr(expr) + mod = relay.transform.InferType()(mod) + entry = mod["main"] + return entry if isinstance(expr, relay.Function) else entry.body + + +def test_requantize(): + def verify(func, goldens): + with relay.build_config(opt_level=0): + graph, lib, params = relay.build(func, "llvm", params=None) + golden_data, golden_output = goldens + mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) + mod.set_input("quantized_data",golden_data) + mod.set_input(**params) + mod.run() + res = mod.get_output(0).asnumpy() + np.testing.assert_equal(res, golden_output) + + def get_func(data_shape, data_dtype, out_dtype, use_int_compute, + rounding_mode, input_scale, output_scale, input_zero_point=0, + output_zero_point=0): + quantized_data = relay.var("quantized_data", shape=data_shape, + dtype=data_dtype) + func = relay.op.qnn.requantize( + quantized_data, + input_zero_point=input_zero_point, + output_zero_point=output_zero_point, + input_scale=input_scale, + output_scale=output_scale, + rounding_mode=rounding_mode, + out_dtype=out_dtype, + use_int_compute=use_int_compute) + + func = relay.Function(relay.analysis.free_vars(func), + func) + func = run_infer_type(func) + func = relay.quantize.rewrite(func) + print(func) + return func + + + def run_tests(): + def same_scale_test(): + # Have same scales, everything within range + golden_data = np.arange(-100, 100, 1).astype('int32') + golden_output = golden_data + + for rounding_mode in rounding_modes: + for use_int_compute in [True, False]: + func = get_func(data_shape=(200, ), + data_dtype='int32', + out_dtype="int8", + use_int_compute=use_int_compute, + rounding_mode=rounding_mode, + input_scale=0.5, + output_scale=0.5) + verify(func, (golden_data, golden_output)) + + def downscale_test(): + for rounding_mode in rounding_modes: + for use_int_compute in [True, False]: + func = get_func(data_shape=(32, ), + data_dtype='int32', + out_dtype="int32", + use_int_compute=use_int_compute, + rounding_mode=rounding_mode, + input_scale=1, + output_scale=16) + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype('int32') + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + verify(func, (golden_data, golden_output)) + + # Try negative values + # -8 corresponds to -0.5. 
For FE_UPWARD, this is 0 + golden_data = np.arange(0, -32, -1).astype('int32') + if use_int_compute == True and rounding_mode == "FE_UPWARD": + golden_output = np.repeat([0, -1, -2], [9, 16, 7]) + else: + golden_output = np.repeat([0, -1, -2], [8, 16, 8]) + verify(func, (golden_data, golden_output)) + + # Try a different scale + for use_int_compute in [True, False]: + func = get_func(data_shape=(32, ), + data_dtype='int32', + out_dtype="int8", + use_int_compute=use_int_compute, + rounding_mode=rounding_mode, + input_scale=1, + output_scale=4) + + # Try positive values + # 2I corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype('int32') + golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8], + [2, 4, 4, 4, 4, 4, 4, 4, 2]) + verify(func, (golden_data, golden_output)) + + # Try negative values + # -8 corresponds to -0.5. For FE_UPWARD, this is 0 + golden_data = np.arange(0, -32, -1).astype('int32') + if use_int_compute == True and rounding_mode == "FE_UPWARD": + golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8], + [3, 4, 4, 4, 4, 4, 4, 4, 1]) + else: + golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8], + [2, 4, 4, 4, 4, 4, 4, 4, 2]) + verify(func, (golden_data, golden_output)) + + def upscale_test(): + for rounding_mode in rounding_modes: + for use_int_compute in [True, False]: + func = get_func(data_shape=(32, ), + data_dtype='int32', + out_dtype="int8", + use_int_compute=use_int_compute, + rounding_mode=rounding_mode, + input_scale=2, + output_scale=1) + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype('int32') + golden_output = np.multiply(2, golden_data) + verify(func, (golden_data, golden_output)) + + # Try negative values + # -8 corresponds to -0.5. For FE_UPWARD, this is 0 + golden_data = np.arange(0, -32, -1).astype('int32') + golden_output = np.multiply(2, golden_data) + verify(func, (golden_data, golden_output)) + + def saturation_test(): + for rounding_mode in rounding_modes: + for use_int_compute in [True, False]: + func = get_func(data_shape=(16, ), + data_dtype='int32', + out_dtype="int8", + use_int_compute=use_int_compute, + rounding_mode=rounding_mode, + input_scale=0.5, + output_scale=0.5) + golden_data = np.arange(0, 16, 1).astype('int32') + golden_data = np.add(120, golden_data) + output = np.array([120, 121, 122, 123, 124, 125, 126, 127, + 127, 127, 127, 127, 127, 127, 127, 127]) + golden_output = output + verify(func, (golden_data, golden_output)) + + # Try negative numbers + golden_data = np.arange(0, -16, -1).astype('int32') + golden_data = np.add(-120, golden_data) + output = np.array([-120, -121, -122, -123, -124, -125, -126, -127, + -128, -128, -128, -128, -128, -128, -128, -128]) + golden_output = output + verify(func, (golden_data, golden_output)) + + def zero_point_test(): + # Output zero point + for rounding_mode in rounding_modes: + for use_int_compute in [True, False]: + func = get_func(data_shape=(32, ), + data_dtype='int32', + out_dtype="int32", + use_int_compute=use_int_compute, + rounding_mode=rounding_mode, + input_scale=1, + output_scale=16, + output_zero_point=1) + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype('int32') + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + golden_output = np.add(1, golden_output) + verify(func, (golden_data, golden_output)) + + # Try negative values + # -8 corresponds to -0.5. 
For FE_UPWARD, this is 0 + golden_data = np.arange(-32, -64, -1).astype('int32') + if use_int_compute == True and rounding_mode == "FE_UPWARD": + golden_output = np.repeat([-2, -3, -4], [9, 16, 7]) + else: + golden_output = np.repeat([-2, -3, -4], [8, 16, 8]) + golden_output = np.add(1, golden_output) + verify(func, (golden_data, golden_output)) + + # Input zero point + for rounding_mode in rounding_modes: + for use_int_compute in [True, False]: + func = get_func(data_shape=(32, ), + data_dtype='int32', + out_dtype="int32", + use_int_compute=use_int_compute, + rounding_mode=rounding_mode, + input_scale=1, + output_scale=16, + input_zero_point=16) + + # Try positive values + golden_data = np.arange(32, 64, 1).astype('int32') + golden_output = np.repeat([2, 3, 4], [8, 16, 8]) + golden_output = np.subtract(golden_output, 1) + verify(func, (golden_data, golden_output)) + + # Try negative values + golden_data = np.arange(-32, -64, -1).astype('int32') + if use_int_compute == True and rounding_mode == "FE_UPWARD": + golden_output = np.repeat([-2, -3, -4], [9, 16, 7]) + else: + golden_output = np.repeat([-2, -3, -4], [8, 16, 8]) + golden_output = np.subtract(golden_output, 1) + verify(func, (golden_data, golden_output)) + + + + + if __name__ == "__main__": + same_scale_test() + downscale_test() + upscale_test() + saturation_test() + zero_point_test() + + run_tests() + + +def test_qconv2d_requantize(): + def verify(func, goldens): + with relay.build_config(opt_level=0): + graph, lib, params = relay.build(func, "llvm", params=None) + golden_data, golden_weight, golden_output = goldens + mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) + mod.set_input("quantized_data",golden_data) + mod.set_input("weight",golden_weight) + mod.set_input(**params) + mod.run() + res = mod.get_output(0).asnumpy() + np.testing.assert_equal(res, golden_output) + + def get_func(data_shape, data_dtype, weight_shape, weight_dtype, + strides, out_dtype, rounding_mode, use_int_compute): + input_scale = 0.25098 + kernel_scale = 0.501961 + output_scale = 0.501961 + kernel_size = (2, 2) + quantized_data = relay.var("quantized_data", shape=data_shape, + dtype=data_dtype) + quantized_weight = relay.var("weight", shape=weight_shape, + dtype=weight_dtype) + func = relay.op.qnn.conv2d( + quantized_data, quantized_weight, + input_zero_point=0, + kernel_zero_point=0, + kernel_size=kernel_size, + strides=strides, + out_dtype="int32", + data_layout="NCHW", + kernel_layout="OIHW") + func = relay.op.qnn.requantize( + func, + input_zero_point=0, + output_zero_point=0, + input_scale=input_scale*kernel_scale, + output_scale=output_scale, + out_dtype=out_dtype, + rounding_mode=rounding_mode, + use_int_compute=use_int_compute) + + func = relay.Function(relay.analysis.free_vars(func), + func) + func = run_infer_type(func) + func = relay.quantize.rewrite(func) + print(func) + return func + + def run_tests(): + def st1_basic_tests(): + # NCHW input + golden_data = np.array([2, 1, 4, 3, + 5, 4, 2, 3, + 3, 8, 4, 9, + 6, 10, 1, 2]).astype('uint8')\ + .reshape((2, 1, 2, 4)) + # OIHW weight + golden_weight = np.array([2, 4, 6, 8, + 4, 2, 6, 4, + 0, 4, 2, 8]).astype('uint8')\ + .reshape((3, 1, 2, 2)) + + golden_output = np.array([18, 15, 14, + 14, 11, 12, + 12, 10, 10, + 39, 25, 17, + 26, 26, 12, + 31, 11, 14]).astype('uint8')\ + .reshape((2, 3, 1, 3)) + + for use_int_compute in [True, False]: + for rounding_mode in rounding_modes: + func = get_func(data_shape=(2, 1, 2, 4), + data_dtype='uint8', + weight_shape=(3, 1, 2, 2), + weight_dtype='uint8', + 
strides=(1, 1), + out_dtype="uint8", + rounding_mode=rounding_mode, + use_int_compute=use_int_compute) + verify(func, (golden_data, golden_weight, golden_output)) + + # Check the int8 input type as well + for use_int_compute in [True, False]: + for rounding_mode in rounding_modes: + func = get_func(data_shape=(2, 1, 2, 4), + data_dtype='int8', + weight_shape=(3, 1, 2, 2), + weight_dtype='int8', + strides=(1, 1), + out_dtype="uint8", + rounding_mode=rounding_mode, + use_int_compute=use_int_compute) + verify(func, (golden_data, golden_weight, golden_output)) + + + def st2_basic_tests(): + # NCHW input + golden_data = np.array([2, 1, 4, 3, + 5, 4, 2, 3, + 3, 8, 4, 9, + 6, 10, 1, 2]).astype('uint8')\ + .reshape((2, 1, 2, 4)) + # OIHW weight + golden_weight = np.array([2, 4, 6, 8, + 4, 2, 6, 4, + 0, 4, 2, 8]).astype('uint8')\ + .reshape((3, 1, 2, 2)) + + + golden_output = np.array([18, 14, 14, + 12, 12, 10, + 39, 17, 26, + 12, 31, 14]).astype('uint8')\ + .reshape((2, 3, 1, 2)) + + for use_int_compute in [True, False]: + for rounding_mode in rounding_modes: + func = get_func(data_shape=(2, 1, 2, 4), + data_dtype='uint8', + weight_shape=(3, 1, 2, 2), + weight_dtype='uint8', + strides=(2, 2), + out_dtype="uint8", + rounding_mode=rounding_mode, + use_int_compute=use_int_compute) + verify(func, (golden_data, golden_weight, golden_output)) + + def st2_saturating_tests(): + # Check the output of saturating test + golden_data = np.array([239, 239, 4, 3, + 6, 4, 2, 3, + 3, 8, 4, 9, + 6, 10, 1, 2]).astype('uint8')\ + .reshape((2, 1, 2, 4)) + # OIHW weight + golden_weight = np.array([2, 4, 6, 8, + 4, 2, 6, 4, + 0, 4, 2, 8]).astype('uint8')\ + .reshape((3, 1, 2, 2)) + + + # Check uint8 output clamping + golden_output = np.array([255, 14, 255, + 12, 251, 10, + 39, 17, 26, + 12, 31, 14]).astype('uint8')\ + .reshape((2, 3, 1, 2)) + for use_int_compute in [True, False]: + for rounding_mode in rounding_modes: + func = get_func(data_shape=(2, 1, 2, 4), + data_dtype='uint8', + weight_shape=(3, 1, 2, 2), + weight_dtype='uint8', + strides=(2, 2), + out_dtype="uint8", + rounding_mode=rounding_mode, + use_int_compute=use_int_compute) + verify(func, (golden_data, golden_weight, golden_output)) + + # Check int8 output clamping + golden_output = np.array([127, 14, 127, + 12, 127, 10, + 39, 17, 26, + 12, 31, 14]).astype('uint8')\ + .reshape((2, 3, 1, 2)) + for use_int_compute in [True, False]: + for rounding_mode in rounding_modes: + func = get_func(data_shape=(2, 1, 2, 4), + data_dtype='uint8', + weight_shape=(3, 1, 2, 2), + weight_dtype='uint8', + strides=(2, 2), + out_dtype="int8", + rounding_mode=rounding_mode, + use_int_compute=use_int_compute) + verify(func, (golden_data, golden_weight, golden_output)) + + + # Check that int16 does not clamp + golden_output = np.array([377, 14, 373, + 12, 251, 10, + 39, 17, 26, + 12, 31, 14]).astype('uint16')\ + .reshape((2, 3, 1, 2)) + for use_int_compute in [True, False]: + for rounding_mode in rounding_modes: + func = get_func(data_shape=(2, 1, 2, 4), + data_dtype='uint8', + weight_shape=(3, 1, 2, 2), + weight_dtype='uint8', + strides=(2, 2), + out_dtype="int16", + rounding_mode=rounding_mode, + use_int_compute=use_int_compute) + verify(func, (golden_data, golden_weight, golden_output)) + + #def cast_test(): + # data = relay.var("data", shape=(1, 5), dtype="float32") # --> [63, -4, 76, 0, -5] + # data1 = relay.var("data1", shape=(1, 5), dtype="float32") + # data2 = relay.var("data2", shape=(1, 5), dtype="float32") + # relu1 = relay.op.nn.relu(data) + # relu2 = 
relay.op.nn.relu(data1) + # relu3 = relay.op.nn.relu(data2) + # func = relay.op.where(relu1, relu2, relu3) + # func = relay.op.nn.relu(func) + # print(func) + # func = relay.Function(relay.analysis.free_vars(func), + # func) + + + # with relay.build_config(opt_level=1): + # graph, lib, params = relay.build(func, "llvm", params=None) + # golden_data = np.array([[63, -4, 76, 0, -5]]) + # golden_data1 = np.array([[1, 2, 3, 4, 5]]) + # golden_data2 = np.array([[6, 7, 8, 9, 10]]) + # mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) + # mod.set_input("data",golden_data) + # mod.set_input("data1",golden_data1) + # mod.set_input("data2",golden_data2) + # mod.set_input(**params) + # mod.run() + # res = mod.get_output(0).asnumpy() + # print(res) + # #np.testing.assert_equal(res, golden_output) + + + if __name__ == "__main__": + st1_basic_tests() + st2_basic_tests() + st2_saturating_tests() + # cast_test() + + run_tests() + +if __name__ == "__main__": + test_requantize() + test_qconv2d_requantize()
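
For reference, the floating-point path of qnn.requantize (the RequantizeFloat lowering and the formula in the requantize docstring) can be sketched in plain numpy as below. This is only an illustrative reference under the patch's stated semantics, not code from the patch; the helper name requantize_float_ref is invented here, and numpy's round-half-to-even can differ from the FE_UPWARD / FE_AWAY_FROM_ZERO modes exactly at .5 midpoints.

import numpy as np

def requantize_float_ref(q_input, input_scale, input_zero_point,
                         output_scale, output_zero_point, out_dtype=np.int8):
    # Q_output = zp_output + (scale_input / scale_output) * (Q_input - zp_input)
    real = (q_input.astype(np.float32) - input_zero_point) * (input_scale / output_scale)
    q_out = np.round(real) + output_zero_point
    # Clip to the representable range of out_dtype before the final cast.
    info = np.iinfo(out_dtype)
    return np.clip(q_out, info.min, info.max).astype(out_dtype)

# Same-scale case from same_scale_test: the values pass through unchanged.
print(requantize_float_ref(np.arange(-100, 100, dtype=np.int32), 0.5, 0, 0.5, 0))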
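
The integer path (use_int_compute=True) replaces the float multiply with a fixed-point multiply and shift, following the steps laid out in the comment above RequantizeInt. A minimal numpy sketch of those steps, assuming 32-bit inputs and the FE_UPWARD rounding mode, is below; the function names are invented for illustration and the final clip/cast to out_dtype is omitted.

import math
import numpy as np

def get_fixed_point_multiplier_shift(double_multiplier, idtype_bits=32):
    # Represent float_number as fixed_point_multiplier * 2^shift, where the
    # multiplier is an integer whose decimal point sits just below the sign bit,
    # e.g. 0.25 = 0.5 * 2^-1 and 0.125 = 0.5 * 2^-2.
    if double_multiplier == 0.0:
        return 0, 0
    q, shift = math.frexp(double_multiplier)            # q is in [0.5, 1)
    q_fixed = int(round(q * (1 << (idtype_bits - 1))))
    if q_fixed == 1 << (idtype_bits - 1):               # q rounded up to 1.0
        q_fixed //= 2
        shift += 1
    return q_fixed, shift

def requantize_int_ref(q_input, input_scale, input_zero_point,
                       output_scale, output_zero_point, idtype_bits=32):
    multiplier, shift = get_fixed_point_multiplier_shift(
        input_scale / output_scale, idtype_bits)
    left_shift = max(shift, 0)
    total_right_shift = max(-shift, 0) + idtype_bits - 1
    x = (q_input.astype(np.int64) - input_zero_point) << left_shift
    x = x * multiplier                                   # 64-bit fixed-point product
    x = x + (1 << (total_right_shift - 1))               # FE_UPWARD rounding term
    x = x >> total_right_shift                           # arithmetic shift (floor)
    return x + output_zero_point                         # clip/cast to out_dtype omitted

# Downscale case from downscale_test (scale 1 -> 16): [0]*8 + [1]*16 + [2]*8.
print(requantize_int_ref(np.arange(0, 32, dtype=np.int32), 1.0, 0, 16.0, 0))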
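
For the qnn.conv2d lowering itself, the only case handled in this patch is symmetric quantization (both zero points equal to zero), where the op reduces to an ordinary integer convolution accumulated in int32 and the scale bookkeeping happens entirely in the following requantize. A hypothetical numpy reference for that case, restricted to the NCHW/OIHW, no-padding, no-dilation, groups=1 setting used in the tests (the helper name is invented here):

import numpy as np

def qconv2d_symmetric_ref(q_data, q_weight, strides=(1, 1)):
    # q_data: (N, C, H, W), q_weight: (O, C, KH, KW), both already quantized.
    n, ic, ih, iw = q_data.shape
    oc, _, kh, kw = q_weight.shape
    sh, sw = strides
    oh, ow = (ih - kh) // sh + 1, (iw - kw) // sw + 1
    out = np.zeros((n, oc, oh, ow), dtype=np.int32)
    for b in range(n):
        for o in range(oc):
            for y in range(oh):
                for x in range(ow):
                    window = q_data[b, :, y * sh:y * sh + kh, x * sw:x * sw + kw]
                    out[b, o, y, x] = np.sum(window.astype(np.int32) *
                                             q_weight[o].astype(np.int32))
    return out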