diff --git a/include/tvm/relay/attrs/nn_quantize.h b/include/tvm/relay/attrs/qnn.h similarity index 86% rename from include/tvm/relay/attrs/nn_quantize.h rename to include/tvm/relay/attrs/qnn.h index 9cfa87d52507e..2da3e39bae512 100644 --- a/include/tvm/relay/attrs/nn_quantize.h +++ b/include/tvm/relay/attrs/qnn.h @@ -30,9 +30,8 @@ namespace tvm { namespace relay { -// TODO(anijain2305) - Copy of QuantizedConv2DAttrs. Should we inherit? /*! \brief Attribute for quantized conv2d operator */ -struct QuantizedConv2DAttrs : public tvm::AttrsNode { +struct QConv2DAttrs : public tvm::AttrsNode { // Traditional conv2d attributes. Array strides; Array padding; @@ -48,13 +47,8 @@ struct QuantizedConv2DAttrs : public tvm::AttrsNode { // Quantization related attributes. int32_t input_zero_point; int32_t kernel_zero_point; - int32_t output_zero_point; - double input_scale; - double kernel_scale; - double output_scale; - bool use_integer_computation_for_scale_handling; - TVM_DECLARE_ATTRS(QuantizedConv2DAttrs, "relay.attrs.QuantizedConv2DAttrs") { + TVM_DECLARE_ATTRS(QConv2DAttrs, "relay.attrs.QConv2DAttrs") { TVM_ATTR_FIELD(strides).set_default(Array({1, 1})) .describe("Specifies the strides of the convolution."); TVM_ATTR_FIELD(padding).set_default(Array({0, 0})) @@ -88,32 +82,44 @@ struct QuantizedConv2DAttrs : public tvm::AttrsNode { .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "dimensions respectively. Default to be same as input layout."); - - // use 0 bits to indicate none. TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); - - TVM_ATTR_FIELD(input_zero_point) .describe("The zero point of the input tensor."); TVM_ATTR_FIELD(kernel_zero_point) .describe("The zero point of the kernel tensor."); + } +}; + + +/*! \brief Attribute for requantize operator */ +struct RequantizeAttrs : public tvm::AttrsNode { + double input_scale; + int32_t input_zero_point; + double output_scale; + int32_t output_zero_point; + bool use_int_compute; + DataType out_dtype; + + TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") { + TVM_ATTR_FIELD(input_zero_point) + .describe("The zero point of the input tensor."); TVM_ATTR_FIELD(output_zero_point) .describe("The zero point of the output tensor."); TVM_ATTR_FIELD(input_scale) .describe("The scale of the input tensor."); - TVM_ATTR_FIELD(kernel_scale) - .describe("The scale of the kernel tensor."); TVM_ATTR_FIELD(output_scale) .describe("The scale of the output tensor."); - TVM_ATTR_FIELD(use_integer_computation_for_scale_handling).set_default(false) + TVM_ATTR_FIELD(use_int_compute).set_default(false) .describe("When true, the integer computation is used to handle output scale"); - - + TVM_ATTR_FIELD(out_dtype) + .set_default(NullValue()) + .describe("Output data type, set to explicit type under mixed precision setting"); } }; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_NN_QUANTIZE_H_ diff --git a/include/tvm/relay/quantize_util.h b/include/tvm/relay/quantize_util.h new file mode 100644 index 0000000000000..bb054fb8fb659 --- /dev/null +++ b/include/tvm/relay/quantize_util.h @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file nnvm/compiler/quantize_util.h + * \brief Utility methods needs for quantized ops that can be shared + */ + +#ifndef TVM_QUANTIZE_UTIL_H +#define TVM_QUANTIZE_UTIL_H + +#include +#include "./base.h" + +namespace tvm { +namespace relay { + +inline bool is_Int8(const DataType& dtype) { + return dtype == Int(8); +} + +inline bool is_UInt8(const DataType& dtype) { + return dtype == UInt(8); +} + + +inline bool is_Int16(const DataType& dtype) { + return dtype == Int(16); +} + +inline bool is_UInt16(const DataType& dtype) { + return dtype == UInt(16); +} + +inline bool is_Int32(const DataType& dtype) { + return dtype == Int(32); +} + +inline bool is_UInt32(const DataType& dtype) { + return dtype == UInt(32); +} + + + +inline bool is_Float32(const DataType& dtype) { + return dtype == Float(32); +} + +inline bool is_quantized_type(const DataType& dtype) { + return is_Int8(dtype) || is_UInt8(dtype) + || is_Int16(dtype) || is_UInt16(dtype); +} + +enum class QuantizeOpType : uint8_t { + Quantize_Requantize, + Dequantize, + Requantize +}; + +inline bool is_valid_quantized_op_input_type(const QuantizeOpType &op_type, const DataType &in_dtype) { + switch(op_type) { + case QuantizeOpType::Quantize_Requantize: + return is_Float32(in_dtype) || is_quantized_type(in_dtype); + case QuantizeOpType ::Dequantize: + return is_quantized_type(in_dtype); + case QuantizeOpType ::Requantize: + return is_Int16(in_dtype) || is_Int32(in_dtype); + default: + return false; + } +} + +inline bool is_valid_quantized_op_output_type(const QuantizeOpType &op_type, const DataType &in_dtype) { + switch(op_type) { + case QuantizeOpType::Quantize_Requantize: + return is_quantized_type(in_dtype); + case QuantizeOpType::Dequantize: + return is_Float32(in_dtype); + default: + return false; + } +} + +inline const int32_t get_qmin(const DataType& dtype) { + if (is_Int8(dtype)) { + return std::numeric_limits::min(); + } else if (is_UInt8(dtype)) { + return std::numeric_limits::min(); + } else if (is_Int16(dtype)) { + return std::numeric_limits::min(); + } else if (is_UInt16(dtype)) { + return std::numeric_limits::min(); + } else if (is_Int32(dtype)) { + return std::numeric_limits::min(); + } else if (is_UInt32(dtype)) { + return std::numeric_limits::min(); + } + LOG(FATAL) << "Type not supported\n"; + return -1; +} + + +inline const int32_t get_qmax(const DataType& dtype) { + if (is_Int8(dtype)) { + return std::numeric_limits::max(); + } else if (is_UInt8(dtype)) { + return std::numeric_limits::max(); + } else if (is_Int16(dtype)) { + return std::numeric_limits::max(); + } else if (is_UInt16(dtype)) { + return std::numeric_limits::max(); + } else if (is_Int32(dtype)) { + return std::numeric_limits::max(); + } else if (is_UInt32(dtype)) { + return std::numeric_limits::max(); + } + LOG(FATAL) << "Type not supported\n"; + return -1; +} + +} // namespace relay +} // namespace tvm +#endif //TVM_QUANTIZE_UTIL_H diff --git 
a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index a27ab1dc50ffd..1d634ef18fc0c 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -26,6 +26,7 @@ from .transform import * from .algorithm import * from . import nn +from . import qnn from . import annotation from . import image from . import vision diff --git a/python/tvm/relay/op/nn/__init__.py b/python/tvm/relay/op/nn/__init__.py index 20bc48d879184..ebabbbcd9d3ad 100644 --- a/python/tvm/relay/op/nn/__init__.py +++ b/python/tvm/relay/op/nn/__init__.py @@ -18,5 +18,4 @@ """Neural network related operators.""" from __future__ import absolute_import as _abs from .nn import * -from . import _quantize from . import _nn diff --git a/python/tvm/relay/op/qnn/__init__.py b/python/tvm/relay/op/qnn/__init__.py new file mode 100644 index 0000000000000..e42063eed26c0 --- /dev/null +++ b/python/tvm/relay/op/qnn/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import +"""Neural network related operators.""" +from __future__ import absolute_import as _abs +from .qnn import * +# from . import _nn diff --git a/python/tvm/relay/op/nn/_make_quantize.py b/python/tvm/relay/op/qnn/_make.py similarity index 94% rename from python/tvm/relay/op/nn/_make_quantize.py rename to python/tvm/relay/op/qnn/_make.py index 2480c99068c4c..b1695629b8f9e 100644 --- a/python/tvm/relay/op/nn/_make_quantize.py +++ b/python/tvm/relay/op/qnn/_make.py @@ -17,4 +17,4 @@ """Constructor APIs""" from ...._ffi.function import _init_api -_init_api("relay.op.nn._quantize._make", __name__) +_init_api("relay.op.qnn._make", __name__) diff --git a/python/tvm/relay/op/nn/_quantize.py b/python/tvm/relay/op/qnn/qnn.py similarity index 60% rename from python/tvm/relay/op/nn/_quantize.py rename to python/tvm/relay/op/qnn/qnn.py index 56a193ac4205f..e695cb89b88a3 100644 --- a/python/tvm/relay/op/nn/_quantize.py +++ b/python/tvm/relay/op/qnn/qnn.py @@ -17,32 +17,31 @@ #pylint: disable=invalid-name, too-many-lines """Neural network operations.""" from __future__ import absolute_import as _abs -from . import _make_quantize - - -def quantized_conv2d(quantized_data, - quantized_weight, - input_zero_point, - kernel_zero_point, - output_zero_point, - input_scale, - kernel_scale, - output_scale, - strides=(1, 1), - padding=(0, 0), - dilation=(1, 1), - groups=1, - channels=None, - kernel_size=None, - data_layout="NCHW", - kernel_layout="OIHW", - out_layout="", - out_dtype=""): +from . 
import _make + + +def conv2d(quantized_data, + quantized_weight, + input_zero_point, + kernel_zero_point, + strides=(1, 1), + padding=(0, 0), + dilation=(1, 1), + groups=1, + channels=None, + kernel_size=None, + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="", + out_dtype="int32"): r"""Quantized 2D convolution. This operator takes the quantized_weight as the convolution kernel - and convolves it with quantized_data to produce an output. - + and convolves it with quantized_data to produce an output quantized tensor. + The scale of the output quantized tensor is the product of the weight_scale + and input_scale of the input quantized tensors. The zero point of the output + quantized tensor is 0. By default, the dtype of output is int32. Please also + see the Requantize operator to understand how the dtype is scaled back to (u)int8. In the default case, where the data_layout is `NCHW` and kernel_layout is `OIHW`, conv2d takes in @@ -71,24 +70,12 @@ def quantized_conv2d(quantized_data, quantized_weight : tvm.relay.Expr The quantized_weight expressions. - input_scale: float - The float scalar to scale the quantized_data int8 values back to FP32. - - kernel_scale: float - The float scalar to scale the quantized_kernel int8 values back to FP32. - - output_scale: float - The float scalar to scale the quantized_output int8 values back to FP32. - - input_zero_point: int + input_zero_point: int The zero point of the quantized_data distribution. - kernel_zero_point: int + kernel_zero_point: int The zero point of the quantized_kernel distribution. - output_zero_point: int - The zero point of the quantized_output distribution. - strides : tuple of int, optional The strides of convolution. @@ -124,10 +111,54 @@ def quantized_conv2d(quantized_data, result : tvm.relay.Expr The computed result. """ - return _make_quantize.quantized_conv2d(quantized_data, quantized_weight, - input_zero_point, kernel_zero_point, output_zero_point, - input_scale, kernel_scale, output_scale, + return _make.conv2d(quantized_data, quantized_weight, + input_zero_point, kernel_zero_point, strides, padding, dilation, groups, channels, kernel_size, - data_layout, kernel_layout, out_layout, - out_dtype) + data_layout, kernel_layout, out_layout, out_dtype) + +def requantize(input_data, input_zero_point, input_scale, output_zero_point, + output_scale, out_dtype="int32", use_int_compute=False): + r"""Requantize operator. + + The requantize operator converts one quantized tensor to another quantized + tensor. For the output tensor, we are provided with output scale and zero + point. The computation looks like this + + Q_output = zp_output + (scale_input)/(scale_output) * (Q_input - zp_input) + + The above computation can be done in floating point as the scales are in + FP32. Alternatively, we can approximate floating point with fixed point + computation. This is controlled by use_int_compute. + + Parameters + ---------- + input_data : tvm.relay.Expr + The input quantized tensor to the operator. + + input_scale: float + The float scalar to scale the quantized_data int8 values back to FP32. + + output_scale: float + The float scalar to scale the quantized_output int8 values back to FP32. + + input_zero_point: int + The zero point of the quantized_data distribution. + + output_zero_point: int + The zero point of the quantized_output distribution. + + out_dtype : str, optional + Specifies the output data type of the requantized tensor. + + use_int_compute : bool, optional + Use fully integer computation for requantizing.
+ + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + return _make.requantize(input_data, input_zero_point, input_scale, + output_zero_point, output_scale, out_dtype, + use_int_compute) diff --git a/python/tvm/relay/quantize/__init__.py b/python/tvm/relay/quantize/__init__.py index d51ae80e9fb15..8da4e7953566f 100644 --- a/python/tvm/relay/quantize/__init__.py +++ b/python/tvm/relay/quantize/__init__.py @@ -19,5 +19,5 @@ from __future__ import absolute_import as _abs from .quantize import * -from .quantize_rewrite import * +from .rewrite import * from ._annotate import register_annotate_function diff --git a/python/tvm/relay/quantize/quantize_rewrite.py b/python/tvm/relay/quantize/rewrite.py similarity index 94% rename from python/tvm/relay/quantize/quantize_rewrite.py rename to python/tvm/relay/quantize/rewrite.py index c2099b65298e6..89429e522115a 100644 --- a/python/tvm/relay/quantize/quantize_rewrite.py +++ b/python/tvm/relay/quantize/rewrite.py @@ -21,7 +21,7 @@ from . import _quantize from .. import expr as _expr -def quantize_rewrite(expr): +def rewrite(expr): """ Rewrites the high-level quantized ops into low-level exisiting Relay ops. @@ -35,4 +35,4 @@ def quantize_rewrite(expr): expr : tvm.relay.Expr The output expression. """ - return _quantize.quantize_rewrite(expr) + return _quantize.rewrite(expr) diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 97cba79640005..ed50080346bb2 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include "../../pass/alter_op_layout.h" @@ -35,6 +36,7 @@ namespace relay { // relay.nn.conv2d TVM_REGISTER_NODE_TYPE(Conv2DAttrs); +template bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, @@ -46,7 +48,7 @@ bool Conv2DRel(const Array& types, static const Layout kNCHW("NCHW"); static const Layout kOIHW("OIHW"); - const Conv2DAttrs* param = attrs.as(); + const auto param = attrs.as(); CHECK(param != nullptr); const Layout in_layout(param->data_layout); const Layout kernel_layout(param->kernel_layout); @@ -191,7 +193,7 @@ with the layer input to produce a tensor of outputs. .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(2) -.add_type_rel("Conv2D", Conv2DRel) +.add_type_rel("Conv2D", Conv2DRel) .set_attr("FInferCorrectLayout", Conv2DInferCorrectLayout); @@ -701,7 +703,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc") .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(10) -.add_type_rel("Conv2D", Conv2DRel) +.add_type_rel("Conv2D", Conv2DRel) .set_attr("FInferCorrectLayout", Conv2DInferCorrectLayout); @@ -751,7 +753,7 @@ RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc") .add_argument("data", "Tensor", "The input tensor.") .add_argument("weight", "Tensor", "The weight tensor.") .set_support_level(10) -.add_type_rel("Conv2D", Conv2DRel) +.add_type_rel("Conv2D", Conv2DRel) .set_attr("FInferCorrectLayout", Conv2DInferCorrectLayout); @@ -896,6 +898,66 @@ Expr MakeDeformableConv2D(Expr data, TVM_REGISTER_API("relay.op.nn._make.deformable_conv2d") .set_body_typed(MakeDeformableConv2D); +// relay.op.qnn.conv2d +TVM_REGISTER_NODE_TYPE(QConv2DAttrs); + +// Positional relay function to create quantized conv2d operator +// used by frontend FFI. 
+Expr MakeQConv2D(Expr quantized_data, + Expr quantized_weight, + int32_t input_zero_point, + int32_t kernel_zero_point, + Array strides, + Array padding, + Array dilation, + int groups, + IndexExpr channels, + Array kernel_size, + std::string data_layout, + std::string kernel_layout, + std::string out_layout, + DataType out_dtype) { + auto attrs = make_node(); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); + attrs->groups = groups; + attrs->channels = std::move(channels); + attrs->kernel_size = std::move(kernel_size); + attrs->data_layout = std::move(data_layout); + attrs->kernel_layout = std::move(kernel_layout); + attrs->out_layout = std::move(out_layout); + attrs->out_dtype = std::move(out_dtype); + attrs->input_zero_point = std::move(input_zero_point); + attrs->kernel_zero_point = std::move(kernel_zero_point); + static const Op& op = Op::Get("qnn.conv2d"); + return CallNode::make(op, {quantized_data, quantized_weight}, Attrs(attrs), {}); +} + +RELAY_REGISTER_OP("qnn.conv2d") +.describe(R"code(2D quantized convolution layer. + +This operator creates a quantized convolution kernel that is convolved +with the quantized input to produce a tensor of quantized outputs. The +operator is further lowered to existing set of Relay operators. + +- **quantized_data**: This depends on the `layout` parameter. Input is 4D array of shape + (batch_size, in_channels, height, width) if `layout` is `NCHW`. +- **quantized_weight**: (channels, in_channels, kernel_size[0], kernel_size[1]) +- **quantized_out**: This depends on the `layout` parameter. Output is 4D array of shape + (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. + +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.QConv2DAttrs") +.set_num_inputs(2) +.add_argument("quantized_data", "Tensor", "The quantized input quantized_data tensor.") +.add_argument("quantized_weight", "Tensor", "The quantized quantized_weight tensor.") +.set_support_level(10) +.add_type_rel("QConv2D", Conv2DRel); + +TVM_REGISTER_API("relay.op.qnn._make.conv2d") +.set_body_typed(MakeQConv2D); + } // namespace relay } // namespace tvm diff --git a/src/relay/op/nn/quantized_convolution.cc b/src/relay/op/nn/quantized_convolution.cc deleted file mode 100644 index af243e237d285..0000000000000 --- a/src/relay/op/nn/quantized_convolution.cc +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2018 by Contributors - * \file quantized_convolution.cc - * \brief Quantized convolution operators - */ - -#include -#include -#include -#include - -namespace tvm { -namespace relay { - -TVM_REGISTER_NODE_TYPE(QuantizedConv2DAttrs); - -// TODO(anijain2305) - Copy of Conv2D Rel. 
Should be share? -// Need separation of header/implementation. -bool QuantizeConv2DRel(const Array& types, - int num_inputs, - const Attrs& attrs, - const TypeReporter& reporter) { - CHECK_EQ(types.size(), 3); - const auto* data = types[0].as(); - const auto* weight = types[1].as(); - if (data == nullptr) return false; - static const Layout kNCHW("NCHW"); - static const Layout kOIHW("OIHW"); - - const QuantizedConv2DAttrs* param = attrs.as(); - CHECK(param != nullptr); - DataType out_dtype = param->out_dtype; - CHECK_NE(out_dtype, NullValue()) - << "Quantized convolution out_dtype has to be passed\n"; - const Layout in_layout(param->data_layout); - const Layout kernel_layout(param->kernel_layout); - - const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW); - CHECK(trans_in_layout.defined()) - << "Conv only support input layouts that are convertible from NCHW." - << " But got " << in_layout; - - const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW); - CHECK(trans_kernel_layout.defined()) - << "Conv only support kernel layouts that are convertible from OIHW." - << " But got "<< kernel_layout; - - Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); - const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW); - CHECK(trans_out_layout.defined()) - << "Conv only support output layouts that are convertible from NCHW." - << " But got " << out_layout; - - Array dshape_nchw = trans_in_layout.ForwardShape(data->shape); - - IndexExpr channels, dilated_ksize_y, dilated_ksize_x; - // infer weight if the kernel_size and channels are defined - if (param->kernel_size.defined() && param->channels.defined()) { - CHECK_EQ(param->kernel_size.size(), 2); - CHECK_EQ(param->dilation.size(), 2); - Array wshape( - {param->channels, - dshape_nchw[1] / param->groups, - param->kernel_size[0], - param->kernel_size[1]}); - wshape = trans_kernel_layout.BackwardShape(wshape); - channels = param->channels; - dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; - dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; - // assign result to reporter - reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype)); - } else { - // use weight to infer the conv shape. 
- if (weight == nullptr) return false; - auto wshape = trans_kernel_layout.ForwardShape(weight->shape); - if (param->kernel_size.defined()) { - CHECK_EQ(param->kernel_size.size(), 2); - // check the size - CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && - reporter->AssertEQ(param->kernel_size[1], wshape[3])) - << "Conv2D: shape of weight is inconsistent with kernel_size, " - << " kernel_size=" << param->kernel_size - << " wshape=" << wshape; - } - if (param->channels.defined()) { - CHECK(reporter->AssertEQ(param->channels, wshape[0])) - << "Conv2D: shape of weight is inconsistent with channels, " - << " channels=" << param->channels - << " wshape=" << wshape; - } - CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[1])); - channels = wshape[0]; - dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0]; - dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1]; - } - // dilation - Array oshape({dshape_nchw[0], channels, 0, 0}); - - oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y) / param->strides[0] + 1); - oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x) / param->strides[1] + 1); - if (out_dtype.bits() == 0) { - out_dtype = data->dtype; - } - oshape = trans_out_layout.BackwardShape(oshape); - // assign output type - reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); - return true; -} - - -// Positional relay function to create quantized conv2d operator -// used by frontend FFI. -Expr MakeQuantizeConv2D(Expr quantized_data, - Expr quantized_weight, - int32_t input_zero_point, - int32_t kernel_zero_point, - int32_t output_zero_point, - double input_scale, - double kernel_scale, - double output_scale, - Array strides, - Array padding, - Array dilation, - int groups, - IndexExpr channels, - Array kernel_size, - std::string data_layout, - std::string kernel_layout, - std::string out_layout, - DataType out_dtype) { - auto attrs = make_node(); - attrs->strides = std::move(strides); - attrs->padding = std::move(padding); - attrs->dilation = std::move(dilation); - attrs->groups = groups; - attrs->channels = std::move(channels); - attrs->kernel_size = std::move(kernel_size); - attrs->data_layout = std::move(data_layout); - attrs->kernel_layout = std::move(kernel_layout); - attrs->out_layout = std::move(out_layout); - attrs->out_dtype = std::move(out_dtype); - attrs->input_zero_point = std::move(input_zero_point); - attrs->kernel_zero_point = std::move(kernel_zero_point); - attrs->output_zero_point = std::move(output_zero_point); - attrs->input_scale = std::move(input_scale); - attrs->kernel_scale = std::move(kernel_scale); - attrs->output_scale = std::move(output_scale); - static const Op& op = Op::Get("nn_quantized.quantized_conv2d"); - return CallNode::make(op, {quantized_data, quantized_weight}, Attrs(attrs), {}); -} - -RELAY_REGISTER_OP("nn_quantized.quantized_conv2d") -.describe(R"code(2D quantized convolution layer. - -This operator creates a quantized convolution kernel that is convolved -with the quantized input to produce a tensor of quantized outputs. The -operator is further lowered to existing set of Relay operators. - -- **quantized_data**: This depends on the `layout` parameter. Input is 4D array of shape - (batch_size, in_channels, height, width) if `layout` is `NCHW`. -- **quantized_weight**: (channels, in_channels, kernel_size[0], kernel_size[1]) -- **quantized_out**: This depends on the `layout` parameter. 
Output is 4D array of shape - (batch_size, channels, out_height, out_width) if `layout` is `NCHW`. - -)code" TVM_ADD_FILELINE) -.set_attrs_type_key("relay.attrs.QuantizedConv2DAttrs") -.set_num_inputs(2) -.add_argument("quantized_data", "Tensor", "The quantized input quantized_data tensor.") -.add_argument("quantized_weight", "Tensor", "The quantized quantized_weight tensor.") -.set_support_level(10) -.add_type_rel("QuantizeConv2D", QuantizeConv2DRel); - -TVM_REGISTER_API("relay.op.nn._quantize._make.quantized_conv2d") -.set_body_typed(MakeQuantizeConv2D); - -} // namespace relay -} // namespace tvm diff --git a/src/relay/op/nn/requantize.cc b/src/relay/op/nn/requantize.cc new file mode 100644 index 0000000000000..813ab4d05f134 --- /dev/null +++ b/src/relay/op/nn/requantize.cc @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2018 by Contributors + * \file requantize.cc + * \brief Quantized convolution operators + */ + +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(RequantizeAttrs); + + +bool RequantizeRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + const auto input_dtype = data->dtype; + CHECK(is_valid_quantized_op_input_type(QuantizeOpType::Requantize, input_dtype)) + << "Input type should be a quantized type (u)int8 or (u)int16 but was " << input_dtype; + + const Array oshape = data->shape; + // assign output type + const RequantizeAttrs* param = attrs.as(); + reporter->Assign(types[1], TensorTypeNode::make(oshape, param->out_dtype)); + return true; +} + +// Positional relay function to create quantized conv2d operator +// used by frontend FFI. +Expr MakeRequantize(Expr data, + int32_t input_zero_point, + double input_scale, + int32_t output_zero_point, + double output_scale, + DataType out_dtype, + bool use_int_compute) { + auto attrs = make_node(); + attrs->out_dtype = std::move(out_dtype); + attrs->input_zero_point = std::move(input_zero_point); + attrs->output_zero_point = std::move(output_zero_point); + attrs->input_scale = std::move(input_scale); + attrs->output_scale = std::move(output_scale); + attrs->use_int_compute = std::move(use_int_compute); + static const Op& op = Op::Get("qnn.requantize"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +RELAY_REGISTER_OP("qnn.requantize") +.describe(R"code(Requantize operator. 
+ +FIXME +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.RequantizeAttrs") +.set_num_inputs(1) +.add_argument("data", "Tensor", "The quantized input tensor.") +.set_support_level(10) +.add_type_rel("Requantize", RequantizeRel); + +TVM_REGISTER_API("relay.op.qnn._make.requantize") +.set_body_typed(MakeRequantize); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 6047593d7f421..65226241aa62e 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -399,6 +399,25 @@ inline Expr Conv2D(Expr data, return CallNode::make(op, {data, weight}, Attrs(attrs), {}); } +inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) { + static const Op& op = Op::Get("where"); + return CallNode::make(op, {condition, x, y}); +} + +inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) { + static const Op& op = Op::Get("greater_equal"); + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); +} + +inline Expr Full(Expr fill_value, + Array shape, + DataType dtype) { + auto attrs = make_node(); + attrs->shape = std::move(shape); + attrs->dtype = std::move(dtype); + static const Op& op = Op::Get("full"); + return CallNode::make(op, {fill_value}, Attrs(attrs), {}); +} Expr MakeConcatenate(Expr data, int axis); Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides); diff --git a/src/relay/pass/quantize_rewrite.cc b/src/relay/pass/quantize_rewrite.cc index a64c3e758a2bb..bd98602a72ff0 100644 --- a/src/relay/pass/quantize_rewrite.cc +++ b/src/relay/pass/quantize_rewrite.cc @@ -25,161 +25,309 @@ #include #include -#include +#include +#include #include "pattern_util.h" namespace tvm { namespace relay { -Expr ConvolveQuantizedTensors(const Expr& quantized_data, - const Expr& quantized_kernel, const QuantizedConv2DAttrs*& param) { - // TODO (janimesh) - Who should decide the accumulation dtype? - if (param->input_zero_point == 0 && param->kernel_zero_point == 0) { - Expr int8_conv = Conv2D(quantized_data, - quantized_kernel, - param->strides, - param->padding, - param->dilation, - param->groups, - param->channels, - param->kernel_size, - param->data_layout, - param->kernel_layout, - param->out_layout, - Int(32)); - return int8_conv; + +// Lowering of qnn.requantize op +void GetFixedPointMultiplierShift(double double_multiplier, + int32_t* fixed_point_multiplier, int* shift, + const DataType& idtype) { + + int acc_dtype_bits = idtype.bits(); + + if (double_multiplier == 0.) { + *fixed_point_multiplier = 0; + *shift = 0; + return; } - LOG(FATAL) << "Only symmetric quantization supported"; - return Expr(); // to hide the warning. + const double q = std::frexp(double_multiplier, shift); + auto q_fixed = static_cast(std::round(q * (1ll << (acc_dtype_bits - 1)))); + CHECK_LE(q_fixed, (1ll << (acc_dtype_bits - 1))); + if (q_fixed == (1ll << (acc_dtype_bits - 1))) { + q_fixed /= 2; + ++*shift; + } + CHECK_LE(q_fixed, std::numeric_limits::max()); + *fixed_point_multiplier = static_cast(q_fixed); } -Expr ScaleHandling(const Expr& convolved_tensor, - const QuantizedConv2DAttrs*& param) { - // The scale handling can be done in many ways. - // 1) Floating point handling - // Here we can multiply the scale to the convolved_tensor, round to nearest - // integer and then cast back to int32. - // 2) Integer only scale handling - // Here, the computation is converted to a fixed point computation by - // computing output multiplier and shift. 
This is useful, if the target - // device does not support/have very expensive floating point computations. - - if (param->use_integer_computation_for_scale_handling == false) { - double multiplier = (param->input_scale * param->kernel_scale) / - param->output_scale; - auto scalar_multiplier = MakeConstantScalar(Float(32), multiplier); - auto casted_convolved_tensor = Cast(convolved_tensor, Float(32)); - auto scaled_fp32_tensor = Multiply(casted_convolved_tensor, scalar_multiplier); - auto scaled_rounded_fp32_tensor = Round(scaled_fp32_tensor); - auto scaled_tensor = Cast(scaled_rounded_fp32_tensor, Int(32)); - return scaled_tensor; +Expr MultiplyByIntegerMuliplier(const Expr& convolved_tensor, + const int32_t fixed_point_multiplier, const int left_shift, + const RequantizeAttrs*& param, const DataType& idtype, + const Array& out_shape) { + // TODO (janimesh) - How to add the overflow checks here. TFLite code snippet is + // bool overflow = a == b && a == std::numeric_limits::min(); + // return overflow ? std::numeric_limits::max() : .....;/ + + // The calculations are done in upcast of idtype to retain precision. + int acc_dtype_bits = idtype.bits(); + DataType up_idtype = Int(2 * acc_dtype_bits); + + auto tensor = convolved_tensor; + // Typically the left_shift will be 0 if the original scale is > 0.5. + if (left_shift != 0) { + tensor = Multiply(tensor, MakeConstantScalar(idtype, 1 << left_shift)); } - LOG(FATAL) << "Only floating point scale handling is supported for now."; - return Expr(); // to hide the warning. + + // Upcast the computation to Int64 and multiply the multiplier. + Expr scalar = MakeConstantScalar(up_idtype, fixed_point_multiplier); + auto multiplied_t = Multiply(Cast(tensor, up_idtype), scalar); + + // Since, we are performing fixed point computation. We are only interested in + // higher 16/32 bits. But before that, we also need to perform rounding. + // This is fixed point rounding. So, the rounder add scalar depends if the + // input is positive. + auto zero = MakeConstantScalar(up_idtype, 0); + auto pos_threshold = MakeConstantScalar(up_idtype, + 1ll << (acc_dtype_bits - 2)); + auto neg_threshold = MakeConstantScalar(up_idtype, + (1 - (1ll << (acc_dtype_bits - 2)))); + auto pos_rounder = Full(pos_threshold, out_shape, up_idtype); + auto neg_rounder = Full(neg_threshold, out_shape, up_idtype); + auto rounding_scalar = Where(GreaterEqual(multiplied_t, zero), pos_rounder, neg_rounder); + auto rounded_tensor = Add(multiplied_t, rounding_scalar); + + // Perform right shift to get the first 16/32 bits. + // The result is first doubled and the first 15/31 bits are obtained. This is + // done by just right shifting the result by 15/31 bits. + auto right_shift_scalar = MakeConstantScalar(up_idtype, (acc_dtype_bits - 1)); + auto scaled_t = RightShift(rounded_tensor, right_shift_scalar); + auto q_imin = get_qmin(idtype); + auto q_imax = get_qmax(idtype); + auto integer_multiplied_t = Cast(Clip(scaled_t, q_imin, q_imax), + idtype); + return integer_multiplied_t; +} + +Expr ShiftByIntegerShift(const Expr& multiplied_t, + const int& exponent, const RequantizeAttrs*& param, + const DataType& idtype, const Array& out_shape) { + CHECK_GE(exponent, 0); + int acc_dtype_bits = idtype.bits(); + CHECK_LE(exponent, (acc_dtype_bits - 1)); + + // We need to perform rounding. The rounding here is closest to the power + // of 2. The exponent basically represents the decimal point. We need to round + // at the decimal point. 
+ auto tensor = multiplied_t; + if (exponent != 0) { + auto pos_rounder = MakeConstantScalar(idtype, (1ll << (exponent - 1))); + auto neg_rounder = MakeConstantScalar(idtype, (1ll << (exponent - 1)) - 1); + auto pos_rounder_t = Full(pos_rounder, out_shape, idtype); + auto neg_rounder_t = Full(neg_rounder, out_shape, idtype); + + auto zero = MakeConstantScalar(idtype, 0); + auto zero_t = Full(zero, out_shape, idtype); + auto round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, + neg_rounder_t); + tensor = Add(tensor, round_scalar); + } + + // Right shift by exponent to approximate the division. + auto scaled_t = RightShift(tensor, + MakeConstantScalar(idtype, exponent)); + return scaled_t; +} + + +/* + * Requantization using only integer computation. Here, the computation is + * converted to a fixed point computation by computing output multiplier and + * shift. This is useful, if the target device does not support/have very + * expensive floating point computations. + * + * Original compuation is scale_fp32 * quantized_tensor. To convert into + * integer computation, the multiplication with fp32 scalar can be replaced by + * multiplication with an int value and then right shifting the result. This + * approximates the floating point computation with a fixed point computation. + * + * The whole computaition this can be broken down into following steps + * 1) Calculate the integer multiplier and integer shift. + * 2) Multiply the integer multiplier with quantized tensor. + * 3) Right shift the result. + * + * The only thing complicating the above computations is the tedious approach of + * handling rounding. + */ +Expr RequantizeInt(const Expr& convolved_tensor, + const RequantizeAttrs*& param, const DataType& idtype, + const Array& out_shape) { + + double double_multiplier = param->input_scale/param->output_scale; + // 1) Calculating the integer multiplier and integer shift + int32_t fixed_point_multiplier; + int shift; + GetFixedPointMultiplierShift(double_multiplier, &fixed_point_multiplier, + &shift, idtype); + + // 2) Multiply the integer multiplier + int left_shift = shift > 0 ? shift : 0; + int right_shift = shift > 0 ? 0 : -shift; + auto multiplied_t = MultiplyByIntegerMuliplier(convolved_tensor, + fixed_point_multiplier, left_shift, param, idtype, out_shape); + + // 3) Divide by the denominator or right shift the result. + auto scaled_int32_t = ShiftByIntegerShift(multiplied_t, + right_shift, param, idtype, out_shape); + + // 4) Clip to the out_dtype min/max. + auto q_min = std::max(get_qmin(param->out_dtype), get_qmin(idtype)); + auto q_max = std::min(get_qmax(param->out_dtype), get_qmax(idtype)); + auto clipped_t = Clip(scaled_int32_t, q_min, q_max); + auto requantized_output = Cast(clipped_t, param->out_dtype); + return requantized_output; } -Expr ReQuantize(const Expr& scaled_output, - const QuantizedConv2DAttrs*& param) { - Expr requantized_output = Cast(scaled_output, param->out_dtype); +/* + * Requantization using floating computation. Here we can multiply the scale to + * the convolved_tensor, round to nearest integer and then cast back to int32. + */ +Expr RequantizeFloat(const Expr& convolved_tensor, + const RequantizeAttrs*& param, const DataType& idtype, + const Array& out_shape) { + double double_multiplier = param->input_scale/param->output_scale; + auto scalar_multiplier = MakeConstantScalar(Float(32), double_multiplier); + + // Multiply the convolved tensor with the new scale. 
+ auto casted_t = Cast(convolved_tensor, Float(32)); + auto multiplied_t = Round(Multiply(casted_t, scalar_multiplier)); + auto q_imin = get_qmin(idtype); + auto q_imax = get_qmax(idtype); + auto scaled_int32_t = Cast(Clip(multiplied_t, q_imin, q_imax), + idtype); + + // Clip to the out_dtype min/max. + // Clip limits must be smaller than the dtype of the input tensor. + auto q_min = std::max(get_qmin(param->out_dtype), get_qmin(idtype)); + auto q_max = std::min(get_qmax(param->out_dtype), get_qmax(idtype)); + auto clipped_t = Clip(scaled_int32_t, q_min, q_max); + auto requantized_output = Cast(clipped_t, param->out_dtype); return requantized_output; } -Expr QuantizedConv2DForwardRewrite(const Call& ref_call, - const Array& new_args, - const NodeRef& ctx) { +/* + * Lowering of the requantize operation. The requantize operator converts one + * quantized tensor to another quantized tensor. For the output tensor, we are + * provided with output scale and zero point. The computation looks like this + * + * Q_output = zp_output + (scale_input)/(scale_ouptut) * (Q_input - zp_input) + * + * The above computation can be done in floating point as the scales are in + * FP32. Alternatively, we can approximate floating point with fixed point + * computation. This is controlled by use_int_compute. + */ +Expr RequantizeForwardRewrite(const Call& ref_call, + const Array& new_args, const NodeRef& ctx) { + CHECK_EQ(new_args.size(), 1); + Expr quantized_data = new_args[0]; + const auto* param = ref_call->attrs.as(); + + // Find output shape. + Array out_shape; + auto ref_call_t = ref_call->checked_type(); + auto output_tt = ref_call_t.as(); + CHECK(output_tt != nullptr) << "Type information missing." + << " Please run infer_type pass."; + out_shape = output_tt->shape; + + // Find input dtype. + auto ref_input_t = ref_call->args[0]->checked_type(); + auto input_tt = ref_input_t.as(); + CHECK(input_tt != nullptr) << "Type information missing." + << " Please run infer_type pass."; + const auto input_dtype = input_tt->dtype; + + // Check for current quantization support. + CHECK_EQ(param->input_zero_point, 0) + << "Encountered non-zero zero point." + << " Only symmetric quantization supported for now."; + CHECK_EQ(param->output_zero_point, 0) + << "Encountered non-zero zero point." + << " Only symmetric quantization supported for now."; + + if (param->use_int_compute) { + return RequantizeInt(quantized_data, param, input_dtype, out_shape); + } else { + return RequantizeFloat(quantized_data, param, input_dtype, out_shape); + } +} + + +RELAY_REGISTER_OP("qnn.requantize") +.set_attr("FQuantizeForwardRewrite", RequantizeForwardRewrite); + +// Lowering of qnn.conv2d op + +/* + * Lowering of the quantized_convolution. + * + * A quantized tensor is represented in following manner + * A = scale_a x (QA - zp_A) + * where QA is quantized tensor, scale_a and zp_A are quantizations params. + * + * Quantized convlution convolves two quantized tensors and returns a quantized + * tensor of default dtype of int32, with scale equaling to the product of + * scales of input tensors, and a zero point of zero. + * + * For symmetric quantization, the zp_* for all tensors is 0. So, the lowering + * of qnn.conv2d is + * + * QA(n, ic, oh + r, ow + s) (conv) QW(oc, ic, r, s) + * + * For asymmetric computation, we can perform similar unrolling. 
We can find + * more details at + * https://discuss.tvm.ai/t/tf-lite-quantized-conv2d-operator-conversion/2651/8?u=janimesh + */ +Expr QConv2DForwardRewrite(const Call& ref_call, + const Array& new_args, const NodeRef& ctx) { CHECK_EQ(new_args.size(), 2); Expr quantized_data = new_args[0]; Expr quantized_kernel = new_args[1]; - const auto* param = ref_call->attrs.as(); - CHECK_EQ(param->input_zero_point, 0) << "Only symmetric support yet"; - CHECK_EQ(param->kernel_zero_point, 0) << "Only symmetric support yet"; - CHECK_EQ(param->output_zero_point, 0) << "Only symmetric support yet"; - // TODO(janimesh) - The out_dtype should be something else, like "int32". - Expr int8_conv = Conv2D(quantized_data, - quantized_kernel, - param->strides, - param->padding, - param->dilation, - param->groups, - param->channels, - param->kernel_size, - param->data_layout, - param->kernel_layout, - param->out_layout, - Int(32)); - // TODO(janimesh) - The out_dtype should come from outside.. - int8_conv = Cast(int8_conv, param->out_dtype); - // TODO(janimesh) - Look at the literature and use the right scale - // calculations. - return int8_conv; + const auto* param = ref_call->attrs.as(); + + Array out_shape; + auto ref_call_t = ref_call->checked_type(); + auto output_tt = ref_call_t.as(); + CHECK(output_tt != nullptr) << "Type information missing." + << " Please run infer_type pass."; + out_shape = output_tt->shape; // Check for current quantization support. - CHECK_EQ(param->input_zero_point, 0) + CHECK_EQ(param->input_zero_point, 0) << "Encountered non-zero zero point." << " Only symmetric quantization supported for now."; CHECK_EQ(param->kernel_zero_point, 0) << "Encountered non-zero zero point." << " Only symmetric quantization supported for now."; - CHECK_EQ(param->output_zero_point, 0) - << "Encountered non-zero zero point." - << " Only symmetric quantization supported for now."; - CHECK_EQ(param->use_integer_computation_for_scale_handling, false) - << "Currently floating point computation is used for scale handling. " - << "Please switch to False if HW supports floating point arithmetic"; - - // Lowering of the quantized_convolution. - // - // For FP32, the conv output is - // C = conv(A, W) - // or, C(n, oc, oh, ow) = A(n, ic, oh + r, ow + s) * W(oc, ic, r, s) - // where, ic, r, s are reduce axis. - // - // For quantized convolution, each tensor is represented in quantized format - // A = scale_a x (QA - zp_A) - // where QA is quantized tensor, scale_a and zp_A are quantizations params. - // - // For symmetric quantization, the zp_* for all tensors is 0. - // So, the quantized_convolution becomes - // - // scale_c * QC(n, oc, oh, ow) = - // scale_a * QA(n, ic, oh + r, ow + s) x - // scale_w * QW(oc, ic, r, s) - // - // So, to get the quantized tensor C, the computation is - // - // QC(n, oc, oh, ow) = (scale_a * scale_w)/scale_c x - // QA(n, ic, oh + r, ow + s) x QW(oc, ic, r, s) - // - // or, - // QC = K * conv(QA, QB) - // - // For asymmetric computation, we can perform similar unrolling. We can find - // more details at - // https://discuss.tvm.ai/t/tf-lite-quantized-conv2d-operator-conversion/2651/8?u=janimesh - - // The above computation is arranged in following functions - // 1) ConvolveQuantizedTensors - // a) For symmetric, conv(QA, QB). - // b) For asymmetric, it involves 4 terms. - // 2) ScaleHandling - // a) Takes convolved output and scales it. - // b) Can support both float and integer computation. - // 3) Requantize - // a) Converts the intermediate dtype back to int8. 
- Expr convolved_tensor = ConvolveQuantizedTensors(quantized_data, - quantized_kernel, - param); - Expr scaled_output = ScaleHandling(convolved_tensor, param); - Expr requantized_output = ReQuantize(scaled_output, param); - // TODO(janimesh) - Look at the literature and use the right scale - // calculations. - return requantized_output; + + if (param->input_zero_point == 0 && param->kernel_zero_point == 0) { + Expr int8_conv = Conv2D(quantized_data, + quantized_kernel, + param->strides, + param->padding, + param->dilation, + param->groups, + param->channels, + param->kernel_size, + param->data_layout, + param->kernel_layout, + param->out_layout, + param->out_dtype); + return int8_conv; + } + LOG(FATAL) << "Only symmetric quantization supported"; + return Expr(); // to hide the warning. } -RELAY_REGISTER_OP("nn_quantized.quantized_conv2d") -.set_attr("FQuantizeForwardRewrite", QuantizedConv2DForwardRewrite); +RELAY_REGISTER_OP("qnn.conv2d") +.set_attr("FQuantizeForwardRewrite", QConv2DForwardRewrite); -TVM_REGISTER_API("relay._quantize.quantize_rewrite") +TVM_REGISTER_API("relay._quantize.rewrite") .set_body_typed([](const Expr& e) { Expr ret = ForwardRewrite(e, "FQuantizeForwardRewrite", nullptr, nullptr); return ret; diff --git a/tests/python/unittest/test_quantized_ops.py b/tests/python/unittest/test_quantized_ops.py index d0b8cc74ffa10..4433e43d4fffb 100644 --- a/tests/python/unittest/test_quantized_ops.py +++ b/tests/python/unittest/test_quantized_ops.py @@ -16,37 +16,241 @@ # under the License. import tvm +import numpy as np from tvm import relay from tvm.relay.testing import create_workload +from tvm.contrib import graph_runtime -def test_quantized_conv2d(): - quantized_data = relay.var("quantized_data", shape=(1, 128, 16, 16), dtype='int8') - quantized_weight = relay.var("weight", shape=(64, 128, 3, 3), dtype='int8') - quantized_output = relay.op.nn._quantize.quantized_conv2d( \ - quantized_data, quantized_weight, - input_zero_point=0, - kernel_zero_point=0, - output_zero_point=0, - input_scale=0.5, - kernel_scale=0.5, - output_scale=0.5, - channels=64, - kernel_size=(3,3), - out_dtype="int8") - func = relay.Function(relay.ir_pass.free_vars(quantized_output), - quantized_output) - print("###### Original graph starts ######") - print(func) - print("###### Original graph ends ######") - func = relay.ir_pass.infer_type(func) - print("###### TypeInferred graph starts ######") - print(func) - print("###### TypeInferred graph ends ######") - func = relay.quantize.quantize_rewrite(func) - func = relay.ir_pass.infer_type(func) - print("###### Lowered graph starts ######") - print(func) - print("###### Lowered graph ends ######") +# TODOs for janimesh before submitting this patch. +# TODO - Add tests for int8 input/weight dtype +# TODO - opt_level=0 fails mostly due to fusion. +# TODO - opt_level=3 fails, likely culprit kernel layout for int8 +# compute. Work with Rankyung to see if this is the culprit. Handle +# it in a separate patch. 
+ + + +def test_quantized_convolution_op(): + def verify(func, goldens): + with relay.build_config(opt_level=0): + graph, lib, params = relay.build(func, "llvm", params=None) + golden_data, golden_weight, golden_st2_output = goldens + mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) + mod.set_input("quantized_data",golden_data) + mod.set_input("weight",golden_weight) + mod.set_input(**params) + mod.run() + res = mod.get_output(0).asnumpy() + np.testing.assert_equal(res, golden_st2_output) + + def get_func(data_shape, data_dtype, + weight_shape, weight_dtype, strides, out_dtype, use_int_compute): + input_scale = 0.25098 + kernel_scale = 0.501961 + output_scale = 0.501961 + kernel_size = (2, 2) + quantized_data = relay.var("quantized_data", shape=data_shape, + dtype=data_dtype) + quantized_weight = relay.var("weight", shape=weight_shape, + dtype=weight_dtype) + func = relay.op.qnn.conv2d( + quantized_data, quantized_weight, + input_zero_point=0, + kernel_zero_point=0, + kernel_size=kernel_size, + strides=strides, + out_dtype="int32", + data_layout="NCHW", + kernel_layout="OIHW") + func = relay.op.qnn.requantize( + func, + input_zero_point=0, + output_zero_point=0, + input_scale=input_scale*kernel_scale, + output_scale=output_scale, + out_dtype=out_dtype, + use_int_compute=use_int_compute) + + func = relay.Function(relay.ir_pass.free_vars(func), + func) + func = relay.ir_pass.infer_type(func) + func = relay.quantize.rewrite(func) + print(func) + return func + + def run_tests(): + def st1_basic_tests(): + # NCHW input + golden_data = np.array([2, 1, 4, 3, + 5, 4, 2, 3, + 3, 8, 4, 9, + 6, 10, 1, 2]).astype('uint8')\ + .reshape((2, 1, 2, 4)) + # OIHW weight + golden_weight = np.array([2, 4, 6, 8, + 4, 2, 6, 4, + 0, 4, 2, 8]).astype('uint8')\ + .reshape((3, 1, 2, 2)) + + golden_output = np.array([18, 15, 14, + 14, 11, 12, + 12, 10, 10, + 39, 25, 17, + 26, 26, 12, + 31, 11, 14]).astype('uint8')\ + .reshape((2, 3, 1, 3)) + + for use_int_compute in [True, False]: + func = get_func(data_shape=(2, 1, 2, 4), + data_dtype='uint8', + weight_shape=(3, 1, 2, 2), + weight_dtype='uint8', + strides=(1, 1), + out_dtype="uint8", + use_int_compute=use_int_compute) + + # Check the int8 input type as well + for use_int_compute in [True, False]: + func = get_func(data_shape=(2, 1, 2, 4), + data_dtype='int8', + weight_shape=(3, 1, 2, 2), + weight_dtype='int8', + strides=(1, 1), + out_dtype="uint8", + use_int_compute=use_int_compute) + verify(func, (golden_data, golden_weight, golden_output)) + + + def st2_basic_tests(): + # NCHW input + golden_data = np.array([2, 1, 4, 3, + 5, 4, 2, 3, + 3, 8, 4, 9, + 6, 10, 1, 2]).astype('uint8')\ + .reshape((2, 1, 2, 4)) + # OIHW weight + golden_weight = np.array([2, 4, 6, 8, + 4, 2, 6, 4, + 0, 4, 2, 8]).astype('uint8')\ + .reshape((3, 1, 2, 2)) + + + golden_output = np.array([18, 14, 14, + 12, 12, 10, + 39, 17, 26, + 12, 31, 14]).astype('uint8')\ + .reshape((2, 3, 1, 2)) + + for use_int_compute in [True, False]: + func = get_func(data_shape=(2, 1, 2, 4), + data_dtype='uint8', + weight_shape=(3, 1, 2, 2), + weight_dtype='uint8', + strides=(2, 2), + out_dtype="uint8", + use_int_compute=use_int_compute) + verify(func, (golden_data, golden_weight, golden_output)) + + def st2_saturating_tests(): + # Check the output of saturating test + golden_data = np.array([239, 239, 4, 3, + 6, 4, 2, 3, + 3, 8, 4, 9, + 6, 10, 1, 2]).astype('uint8')\ + .reshape((2, 1, 2, 4)) + # OIHW weight + golden_weight = np.array([2, 4, 6, 8, + 4, 2, 6, 4, + 0, 4, 2, 8]).astype('uint8')\ + .reshape((3, 1, 
2, 2)) + + + # Check uint8 output clamping + golden_output = np.array([255, 14, 255, + 12, 251, 10, + 39, 17, 26, + 12, 31, 14]).astype('uint8')\ + .reshape((2, 3, 1, 2)) + for use_int_compute in [True, False]: + func = get_func(data_shape=(2, 1, 2, 4), + data_dtype='uint8', + weight_shape=(3, 1, 2, 2), + weight_dtype='uint8', + strides=(2, 2), + out_dtype="uint8", + use_int_compute=use_int_compute) + verify(func, (golden_data, golden_weight, golden_output)) + + # Check int8 output clamping + golden_output = np.array([127, 14, 127, + 12, 127, 10, + 39, 17, 26, + 12, 31, 14]).astype('uint8')\ + .reshape((2, 3, 1, 2)) + for use_int_compute in [True, False]: + func = get_func(data_shape=(2, 1, 2, 4), + data_dtype='uint8', + weight_shape=(3, 1, 2, 2), + weight_dtype='uint8', + strides=(2, 2), + out_dtype="int8", + use_int_compute=use_int_compute) + verify(func, (golden_data, golden_weight, golden_output)) + + + # Check that int16 does not clamp + golden_output = np.array([377, 14, 373, + 12, 251, 10, + 39, 17, 26, + 12, 31, 14]).astype('uint16')\ + .reshape((2, 3, 1, 2)) + for use_int_compute in [True, False]: + func = get_func(data_shape=(2, 1, 2, 4), + data_dtype='uint8', + weight_shape=(3, 1, 2, 2), + weight_dtype='uint8', + strides=(2, 2), + out_dtype="int16", + use_int_compute=use_int_compute) + verify(func, (golden_data, golden_weight, golden_output)) + + #def cast_test(): + # data = relay.var("data", shape=(1, 5), dtype="float32") # --> [63, -4, 76, 0, -5] + # data1 = relay.var("data1", shape=(1, 5), dtype="float32") + # data2 = relay.var("data2", shape=(1, 5), dtype="float32") + # relu1 = relay.op.nn.relu(data) + # relu2 = relay.op.nn.relu(data1) + # relu3 = relay.op.nn.relu(data2) + # func = relay.op.where(relu1, relu2, relu3) + # func = relay.op.nn.relu(func) + # print(func) + # func = relay.Function(relay.ir_pass.free_vars(func), + # func) + + + # with relay.build_config(opt_level=1): + # graph, lib, params = relay.build(func, "llvm", params=None) + # golden_data = np.array([[63, -4, 76, 0, -5]]) + # golden_data1 = np.array([[1, 2, 3, 4, 5]]) + # golden_data2 = np.array([[6, 7, 8, 9, 10]]) + # mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) + # mod.set_input("data",golden_data) + # mod.set_input("data1",golden_data1) + # mod.set_input("data2",golden_data2) + # mod.set_input(**params) + # mod.run() + # res = mod.get_output(0).asnumpy() + # print(res) + # #np.testing.assert_equal(res, golden_st2_output) + + + if __name__ == "__main__": + st1_basic_tests() + st2_basic_tests() + st2_saturating_tests() + # cast_test() + + run_tests() if __name__ == "__main__": - test_quantized_conv2d() + test_quantized_convolution_op()
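As a reviewer-side reference for the fixed-point path (GetFixedPointMultiplierShift followed by the rounded shifts in MultiplyByIntegerMuliplier and ShiftByIntegerShift), here is a minimal Python sketch of the same arithmetic, assuming 32-bit accumulators and symmetric quantization (zero points of 0). The helper names fixed_point_multiplier and requantize_scalar are illustrative only and are not part of the patch.

import math

def fixed_point_multiplier(mult, bits=32):
    # Mirror of GetFixedPointMultiplierShift: express mult as
    # q_fixed * 2**shift / 2**(bits - 1), with 0.5 <= q_fixed / 2**(bits - 1) < 1.
    if mult == 0.0:
        return 0, 0
    q, shift = math.frexp(mult)                   # mult = q * 2**shift, 0.5 <= q < 1
    q_fixed = int(round(q * (1 << (bits - 1))))
    if q_fixed == 1 << (bits - 1):                # rounding pushed q up to 1.0
        q_fixed //= 2
        shift += 1
    return q_fixed, shift

def requantize_scalar(x, input_scale, output_scale, qmin, qmax, bits=32):
    # Integer-only approximation of round(x * input_scale / output_scale), clipped.
    q_fixed, shift = fixed_point_multiplier(input_scale / output_scale, bits)
    left, right = max(shift, 0), max(-shift, 0)
    acc = (x << left) * q_fixed
    # Round to nearest (half away from zero) before dropping the low (bits - 1) bits.
    acc += (1 << (bits - 2)) if acc >= 0 else (1 << (bits - 2)) - 1
    acc >>= bits - 1
    if right > 0:
        # Second stage: divide by 2**right, again rounding to nearest.
        acc += (1 << (right - 1)) if acc >= 0 else (1 << (right - 1)) - 1
        acc >>= right
    return max(qmin, min(qmax, acc))

With the scales used in the unit test, requantize_scalar(70, 0.25098 * 0.501961, 0.501961, 0, 255) evaluates to 18, matching the float path and the first golden output value.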
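To make the symmetric lowering concrete (qnn.conv2d yields an int32 accumulator whose implicit scale is input_scale * kernel_scale, and qnn.requantize rescales it to output_scale), the first entry of golden_output in st1_basic_tests can be reproduced by hand. The snippet below is plain Python arithmetic, not part of the test:

# Top-left 2x2 window of data[0, 0] (NCHW) against filter weight[0, 0] (OIHW).
window = [2, 1, 5, 4]
kernel = [2, 4, 6, 8]
acc = sum(d * w for d, w in zip(window, kernel))     # 70: the int32 qnn.conv2d output
scale = (0.25098 * 0.501961) / 0.501961              # input_scale * kernel_scale / output_scale
print(round(acc * scale))                            # 18: the first entry of golden_output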
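The saturating tests exercise the final Clip step of requantize: the rescaled value 377 survives only for out_dtype="int16" and is clamped to the (u)int8 ranges otherwise. A plain-Python check of the first window, again assuming symmetric quantization:

acc = 239 * 2 + 239 * 4 + 6 * 6 + 4 * 8      # 1502: first accumulator of the saturating input
val = round(acc * 0.25098)                   # 377 after rescaling
print(min(255, max(0, val)))                 # 255 -> golden value for out_dtype="uint8"
print(min(127, max(-128, val)))              # 127 -> golden value for out_dtype="int8"
print(val)                                   # 377 -> preserved for out_dtype="int16"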