diff --git a/include/tvm/relay/attrs/qnn.h b/include/tvm/relay/attrs/qnn.h index 12afe19d26b3e..cf69fa759c1c0 100644 --- a/include/tvm/relay/attrs/qnn.h +++ b/include/tvm/relay/attrs/qnn.h @@ -37,6 +37,7 @@ struct RequantizeAttrs : public tvm::AttrsNode { double output_scale; int32_t output_zero_point; bool use_int_compute; + std::string rounding_mode; DataType out_dtype; TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") { @@ -48,14 +49,22 @@ struct RequantizeAttrs : public tvm::AttrsNode { .describe("The scale of the input tensor."); TVM_ATTR_FIELD(output_scale) .describe("The scale of the output tensor."); - TVM_ATTR_FIELD(use_int_compute).set_default(false) - .describe("When true, the integer computation is used to handle output scale"); + TVM_ATTR_FIELD(use_int_compute).set_default(true) + .describe("When true, the integer computation is used to handle output scale." + "The float compuation can be used as reference implementation or in" + "cases where FP32 computation for requantize is not expensive"); TVM_ATTR_FIELD(out_dtype) .set_default(NullValue()) .describe("Output data type, set to explicit type under mixed precision setting"); + TVM_ATTR_FIELD(rounding_mode).set_default("FE_UPWARD") + .describe("Defines the rounding direction when the value is midway between" + "two representable values. There are two supported modes - FE_UPWARD" + "or FE_AWAY_FROM_ZERO. More context can be found at" + "https://www.gnu.org/software/libc/manual/html_node/Rounding.html"); } }; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_NN_QUANTIZE_H_ diff --git a/python/tvm/relay/op/qnn/qnn.py b/python/tvm/relay/op/qnn/qnn.py index 18be68cd9cfc4..484b3864f22fb 100644 --- a/python/tvm/relay/op/qnn/qnn.py +++ b/python/tvm/relay/op/qnn/qnn.py @@ -19,9 +19,9 @@ from __future__ import absolute_import as _abs from . import _make - def requantize(input_data, input_zero_point, input_scale, output_zero_point, - output_scale, out_dtype="int32", use_int_compute=False): + output_scale, out_dtype="int32", use_int_compute=False, + rounding_mode="FE_UPWARD"): r"""Requantized operator. The requantize operator converts one quantized tensor to another quantized @@ -57,11 +57,18 @@ def requantize(input_data, input_zero_point, input_scale, output_zero_point, use_int_compute : bool, optional Use fully integer computation for requantizing. + rounding_mode : string, optional + Defines the rounding direction when the value is midway between two + representable values. + Returns ------- result : tvm.relay.Expr The computed result. """ + assert rounding_mode in ("FE_UPWARD", "FE_AWAY_FROM_ZERO"),\ + "Unsupported rounding mode" + return _make.requantize(input_data, input_zero_point, input_scale, output_zero_point, output_scale, out_dtype, - use_int_compute) \ No newline at end of file + use_int_compute, rounding_mode) diff --git a/src/relay/op/nn/requantize.cc b/src/relay/op/nn/requantize.cc index 80f2bde4ad472..285528993f6f8 100644 --- a/src/relay/op/nn/requantize.cc +++ b/src/relay/op/nn/requantize.cc @@ -59,7 +59,8 @@ Expr MakeRequantize(Expr data, int32_t output_zero_point, double output_scale, DataType out_dtype, - bool use_int_compute) { + bool use_int_compute, + std::string rounding_mode) { auto attrs = make_node(); attrs->out_dtype = std::move(out_dtype); attrs->input_zero_point = std::move(input_zero_point); @@ -67,6 +68,7 @@ Expr MakeRequantize(Expr data, attrs->input_scale = std::move(input_scale); attrs->output_scale = std::move(output_scale); attrs->use_int_compute = std::move(use_int_compute); + attrs->rounding_mode = std::move(rounding_mode); static const Op& op = Op::Get("qnn.requantize"); return CallNode::make(op, {data}, Attrs(attrs), {}); } diff --git a/src/relay/pass/quantize_rewrite.cc b/src/relay/pass/quantize_rewrite.cc index 55f8c43fd49fc..645b20c0730e1 100644 --- a/src/relay/pass/quantize_rewrite.cc +++ b/src/relay/pass/quantize_rewrite.cc @@ -33,13 +33,27 @@ namespace tvm { namespace relay { - // Lowering of qnn.requantize op + +/* + * Converts a floating point number so that it can be represented by integers. + * The representation is + * float_number = (fixed_point_multiplier) * 2^(shift) + * + * The fixed_point_multiplier is a number between 0.5 and 1. This is represented + * by an integer number. For example, if it is int32, then the decimal point + * exists between bit 31 and 30 from LSB (or between first and second bit from + * the left). + * + * Some examples are + * 0.25 = (0.5) * 2^(-1) + * 0.125 = (0.5) * 2^(-2) + */ void GetFixedPointMultiplierShift(double double_multiplier, int32_t* fixed_point_multiplier, int* shift, const DataType& idtype) { - int acc_dtype_bits = idtype.bits(); + int idtype_bits = idtype.bits(); if (double_multiplier == 0.) { *fixed_point_multiplier = 0; @@ -47,9 +61,9 @@ void GetFixedPointMultiplierShift(double double_multiplier, return; } const double q = std::frexp(double_multiplier, shift); - auto q_fixed = static_cast(std::round(q * (1ll << (acc_dtype_bits - 1)))); - CHECK_LE(q_fixed, (1ll << (acc_dtype_bits - 1))); - if (q_fixed == (1ll << (acc_dtype_bits - 1))) { + auto q_fixed = static_cast(std::round(q * (1ll << (idtype_bits - 1)))); + CHECK_LE(q_fixed, (1ll << (idtype_bits - 1))); + if (q_fixed == (1ll << (idtype_bits - 1))) { q_fixed /= 2; ++*shift; } @@ -57,85 +71,6 @@ void GetFixedPointMultiplierShift(double double_multiplier, *fixed_point_multiplier = static_cast(q_fixed); } -Expr MultiplyByIntegerMuliplier(const Expr& convolved_tensor, - const int32_t fixed_point_multiplier, const int left_shift, - const RequantizeAttrs*& param, const DataType& idtype, - const Array& out_shape) { - // TODO (janimesh) - How to add the overflow checks here. TFLite code snippet is - // bool overflow = a == b && a == std::numeric_limits::min(); - // return overflow ? std::numeric_limits::max() : .....;/ - - // The calculations are done in upcast of idtype to retain precision. - int acc_dtype_bits = idtype.bits(); - DataType up_idtype = Int(2 * acc_dtype_bits); - - auto tensor = convolved_tensor; - // Typically the left_shift will be 0 if the original scale is > 0.5. - if (left_shift != 0) { - tensor = Multiply(tensor, MakeConstantScalar(idtype, 1 << left_shift)); - } - - // Upcast the computation to Int64 and multiply the multiplier. - Expr scalar = MakeConstantScalar(up_idtype, fixed_point_multiplier); - auto multiplied_t = Multiply(Cast(tensor, up_idtype), scalar); - - // Since, we are performing fixed point computation. We are only interested in - // higher 16/32 bits. But before that, we also need to perform rounding. - // This is fixed point rounding. So, the rounder add scalar depends if the - // input is positive. - auto zero = MakeConstantScalar(up_idtype, 0); - auto pos_threshold = MakeConstantScalar(up_idtype, - 1ll << (acc_dtype_bits - 2)); - auto neg_threshold = MakeConstantScalar(up_idtype, - (1 - (1ll << (acc_dtype_bits - 2)))); - auto pos_rounder = Full(pos_threshold, out_shape, up_idtype); - auto neg_rounder = Full(neg_threshold, out_shape, up_idtype); - auto rounding_scalar = Where(GreaterEqual(multiplied_t, zero), pos_rounder, neg_rounder); - auto rounded_tensor = Add(multiplied_t, rounding_scalar); - - // Perform right shift to get the first 16/32 bits. - // The result is first doubled and the first 15/31 bits are obtained. This is - // done by just right shifting the result by 15/31 bits. - auto right_shift_scalar = MakeConstantScalar(up_idtype, (acc_dtype_bits - 1)); - auto scaled_t = RightShift(rounded_tensor, right_shift_scalar); - auto q_imin = get_qmin(idtype); - auto q_imax = get_qmax(idtype); - auto integer_multiplied_t = Cast(Clip(scaled_t, q_imin, q_imax), - idtype); - return integer_multiplied_t; -} - -Expr ShiftByIntegerShift(const Expr& multiplied_t, - const int& exponent, const RequantizeAttrs*& param, - const DataType& idtype, const Array& out_shape) { - CHECK_GE(exponent, 0); - int acc_dtype_bits = idtype.bits(); - CHECK_LE(exponent, (acc_dtype_bits - 1)); - - // We need to perform rounding. The rounding here is closest to the power - // of 2. The exponent basically represents the decimal point. We need to round - // at the decimal point. - auto tensor = multiplied_t; - if (exponent != 0) { - auto pos_rounder = MakeConstantScalar(idtype, (1ll << (exponent - 1))); - auto neg_rounder = MakeConstantScalar(idtype, (1ll << (exponent - 1)) - 1); - auto pos_rounder_t = Full(pos_rounder, out_shape, idtype); - auto neg_rounder_t = Full(neg_rounder, out_shape, idtype); - - auto zero = MakeConstantScalar(idtype, 0); - auto zero_t = Full(zero, out_shape, idtype); - auto round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, - neg_rounder_t); - tensor = Add(tensor, round_scalar); - } - - // Right shift by exponent to approximate the division. - auto scaled_t = RightShift(tensor, - MakeConstantScalar(idtype, exponent)); - return scaled_t; -} - - /* * Requantization using only integer computation. Here, the computation is * converted to a fixed point computation by computing output multiplier and @@ -147,59 +82,123 @@ Expr ShiftByIntegerShift(const Expr& multiplied_t, * multiplication with an int value and then right shifting the result. This * approximates the floating point computation with a fixed point computation. * - * The whole computaition this can be broken down into following steps + * The whole computation this can be broken down into following steps * 1) Calculate the integer multiplier and integer shift. - * 2) Multiply the integer multiplier with quantized tensor. - * 3) Right shift the result. + * 2) Subtract the input integer point. + * 2) Multiply the integer fixed point multiplier with quantized tensor. + * 3) Round the result. + * 4) Right shift the result. + * 5) Add the output_zero_point. + * 6) Cast to the out_dtype. * - * The only thing complicating the above computations is the tedious approach of - * handling rounding. */ -Expr RequantizeInt(const Expr& convolved_tensor, +Expr RequantizeInt(const Expr& input_tensor, const RequantizeAttrs*& param, const DataType& idtype, const Array& out_shape) { double double_multiplier = param->input_scale/param->output_scale; + + // The multiplication will be performed in higher precision. Find the dtype. + int idtype_bits = idtype.bits(); + DataType up_idtype = Int(2 * idtype_bits); + // 1) Calculating the integer multiplier and integer shift int32_t fixed_point_multiplier; int shift; GetFixedPointMultiplierShift(double_multiplier, &fixed_point_multiplier, &shift, idtype); - - // 2) Multiply the integer multiplier int left_shift = shift > 0 ? shift : 0; int right_shift = shift > 0 ? 0 : -shift; - auto multiplied_t = MultiplyByIntegerMuliplier(convolved_tensor, - fixed_point_multiplier, left_shift, param, idtype, out_shape); - // 3) Divide by the denominator or right shift the result. - auto scaled_int32_t = ShiftByIntegerShift(multiplied_t, - right_shift, param, idtype, out_shape); + // 2) Subtract the input_zero_point + auto tensor = input_tensor; + tensor = Cast(tensor, up_idtype); + if (param->input_zero_point != 0) { + auto input_zp = MakeConstantScalar(up_idtype, param->input_zero_point); + tensor = Subtract(tensor, input_zp); + } - // 4) Clip to the out_dtype min/max. + + + // 3) Multiply the integer multiplier + if (left_shift != 0) { + tensor = Multiply(tensor, MakeConstantScalar(up_idtype, 1 << left_shift)); + } + // Perform the multiplication in higher precision. + // If idtype is Int(32), the scalar is a fixed point value of int32 where the + // decimal point is between bits 31 and 30. After multiplying with + // input_tensor, the result in int64 where the decimal point is sitting + // between bits 31 and 30 (from the right, rightmost bit is bit 0). + Expr scalar = MakeConstantScalar(up_idtype, fixed_point_multiplier); + auto multiplied_t = Multiply(tensor, scalar); + + + // 4) Find the rounding scalar. This depends on where the final decimal point + // sits. As we will be right shifting the multiplied_t, we need to first + // calculate the totol_right_shift. + int total_right_shift = right_shift + idtype_bits - 1; + + tensor = multiplied_t; + Expr round_scalar; + if (param->rounding_mode == "FE_UPWARD") { + auto pos_rounder = MakeConstantScalar(up_idtype, (1ll << (total_right_shift - 1))); + round_scalar = pos_rounder; + } else if (param->rounding_mode == "FE_AWAY_FROM_ZERO") { + auto pos_rounder = MakeConstantScalar(up_idtype, (1ll << (total_right_shift - 1))); + auto neg_rounder = MakeConstantScalar(up_idtype, (1ll << (total_right_shift - 1)) - 1); + auto pos_rounder_t = Full(pos_rounder, out_shape, up_idtype); + auto neg_rounder_t = Full(neg_rounder, out_shape, up_idtype); + + auto zero = MakeConstantScalar(up_idtype, 0); + auto zero_t = Full(zero, out_shape, up_idtype); + round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, + neg_rounder_t); + } + // Add the rounding scalar. + tensor = Add(tensor, round_scalar); + + // 5) Simply right shift the result to get the final output. + auto scaled_int64_t = RightShift(tensor, + MakeConstantScalar(up_idtype, total_right_shift)); + + // 6) Add the output zero point. + auto output_zp = MakeConstantScalar(up_idtype, param->output_zero_point); + auto shifted_int64_t = Add(output_zp, scaled_int64_t); + + // 7) Clip to the out_dtype min/max. + // Find the right clip min/maxes. While clipping, it is necessary that + // clip_min and clip_max are within the dtype range of the input tensor to the + // clip operator. For example, if the input to clip operator is int8, but the + // out_dtype is uint8, we will get incorrect results, if we set max as 255. auto q_min = std::max(get_qmin(param->out_dtype), get_qmin(idtype)); auto q_max = std::min(get_qmax(param->out_dtype), get_qmax(idtype)); - auto clipped_t = Clip(scaled_int32_t, q_min, q_max); + auto clipped_t = Clip(shifted_int64_t, q_min, q_max); auto requantized_output = Cast(clipped_t, param->out_dtype); return requantized_output; } -/* + +/* * Requantization using floating computation. Here we can multiply the scale to - * the convolved_tensor, round to nearest integer and then cast back to int32. + * the input_tensor, round to nearest integer and then cast back to int32. */ -Expr RequantizeFloat(const Expr& convolved_tensor, +Expr RequantizeFloat(const Expr& input_tensor, const RequantizeAttrs*& param, const DataType& idtype, const Array& out_shape) { double double_multiplier = param->input_scale/param->output_scale; auto scalar_multiplier = MakeConstantScalar(Float(32), double_multiplier); - - // Multiply the convolved tensor with the new scale. - auto casted_t = Cast(convolved_tensor, Float(32)); - auto multiplied_t = Round(Multiply(casted_t, scalar_multiplier)); + auto input_zp = MakeConstantScalar(idtype, param->input_zero_point); + auto output_zp = MakeConstantScalar(Float(32), param->output_zero_point); + + // Multiply the tensor with the new scale. + auto shifted_input_t = Subtract(input_tensor, input_zp); + auto casted_t = Cast(shifted_input_t, Float(32)); + auto multiplied_t = Multiply(casted_t, scalar_multiplier); + auto shifted_multiplied_t = Add(output_zp, multiplied_t); + auto rounded_t = Round(shifted_multiplied_t); auto q_imin = get_qmin(idtype); auto q_imax = get_qmax(idtype); - auto scaled_int32_t = Cast(Clip(multiplied_t, q_imin, q_imax), + auto scaled_int32_t = Cast(Clip(rounded_t, q_imin, q_imax), idtype); // Clip to the out_dtype min/max. @@ -243,14 +242,6 @@ Expr RequantizeForwardRewrite(const Call& ref_call, << " Please run infer_type pass."; const auto input_dtype = input_tt->dtype; - // Check for current quantization support. - CHECK_EQ(param->input_zero_point, 0) - << "Encountered non-zero zero point." - << " Only symmetric quantization supported for now."; - CHECK_EQ(param->output_zero_point, 0) - << "Encountered non-zero zero point." - << " Only symmetric quantization supported for now."; - if (param->use_int_compute) { return RequantizeInt(quantized_data, param, input_dtype, out_shape); } else { @@ -258,18 +249,14 @@ Expr RequantizeForwardRewrite(const Call& ref_call, } } - RELAY_REGISTER_OP("qnn.requantize") .set_attr("FQuantizeForwardRewrite", RequantizeForwardRewrite); - - TVM_REGISTER_API("relay._quantize.rewrite") .set_body_typed([](const Expr& e) { - Expr ret = ForwardRewrite(e, "FQuantizeForwardRewrite", nullptr, nullptr); - return ret; -}); - + Expr ret = ForwardRewrite(e, "FQuantizeForwardRewrite", nullptr, nullptr); + return ret; + }); } // namespace relay } // namespace tvm diff --git a/tests/python/unittest/test_quantized_ops.py b/tests/python/unittest/test_quantized_ops.py new file mode 100644 index 0000000000000..e70ea09252313 --- /dev/null +++ b/tests/python/unittest/test_quantized_ops.py @@ -0,0 +1,257 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import numpy as np +from tvm import relay +from tvm.relay.testing import create_workload +from tvm.contrib import graph_runtime + +rounding_modes = ["FE_UPWARD", "FE_AWAY_FROM_ZERO"] + +def run_infer_type(expr): + mod = relay.Module.from_expr(expr) + mod = relay.transform.InferType()(mod) + entry = mod["main"] + return entry if isinstance(expr, relay.Function) else entry.body + + +def test_requantize(): + def verify(func, goldens): + with relay.build_config(opt_level=0): + graph, lib, params = relay.build(func, "llvm", params=None) + golden_data, golden_output = goldens + mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0)) + mod.set_input("quantized_data",golden_data) + mod.set_input(**params) + mod.run() + res = mod.get_output(0).asnumpy() + np.testing.assert_equal(res, golden_output) + + def get_func(data_shape, data_dtype, out_dtype, use_int_compute, + rounding_mode, input_scale, output_scale, input_zero_point=0, + output_zero_point=0): + quantized_data = relay.var("quantized_data", shape=data_shape, + dtype=data_dtype) + func = relay.op.qnn.requantize( + quantized_data, + input_zero_point=input_zero_point, + output_zero_point=output_zero_point, + input_scale=input_scale, + output_scale=output_scale, + rounding_mode=rounding_mode, + out_dtype=out_dtype, + use_int_compute=use_int_compute) + + func = relay.Function(relay.analysis.free_vars(func), + func) + func = run_infer_type(func) + func = relay.quantize.rewrite(func) + print(func) + return func + + + def run_tests(): + def same_scale_test(): + # Have same scales, everything within range + golden_data = np.arange(-100, 100, 1).astype('int32') + golden_output = golden_data + + for rounding_mode in rounding_modes: + for use_int_compute in [True, False]: + func = get_func(data_shape=(200, ), + data_dtype='int32', + out_dtype="int8", + use_int_compute=use_int_compute, + rounding_mode=rounding_mode, + input_scale=0.5, + output_scale=0.5) + verify(func, (golden_data, golden_output)) + + def downscale_test(): + for rounding_mode in rounding_modes: + for use_int_compute in [True, False]: + func = get_func(data_shape=(32, ), + data_dtype='int32', + out_dtype="int32", + use_int_compute=use_int_compute, + rounding_mode=rounding_mode, + input_scale=1, + output_scale=16) + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype('int32') + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + verify(func, (golden_data, golden_output)) + + # Try negative values + # -8 corresponds to -0.5. For FE_UPWARD, this is 0 + golden_data = np.arange(0, -32, -1).astype('int32') + if use_int_compute == True and rounding_mode == "FE_UPWARD": + golden_output = np.repeat([0, -1, -2], [9, 16, 7]) + else: + golden_output = np.repeat([0, -1, -2], [8, 16, 8]) + verify(func, (golden_data, golden_output)) + + # Try a different scale + for use_int_compute in [True, False]: + func = get_func(data_shape=(32, ), + data_dtype='int32', + out_dtype="int8", + use_int_compute=use_int_compute, + rounding_mode=rounding_mode, + input_scale=1, + output_scale=4) + + # Try positive values + # 2I corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype('int32') + golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8], + [2, 4, 4, 4, 4, 4, 4, 4, 2]) + verify(func, (golden_data, golden_output)) + + # Try negative values + # -8 corresponds to -0.5. For FE_UPWARD, this is 0 + golden_data = np.arange(0, -32, -1).astype('int32') + if use_int_compute == True and rounding_mode == "FE_UPWARD": + golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8], + [3, 4, 4, 4, 4, 4, 4, 4, 1]) + else: + golden_output = np.repeat([0, -1, -2, -3, -4, -5, -6, -7, -8], + [2, 4, 4, 4, 4, 4, 4, 4, 2]) + verify(func, (golden_data, golden_output)) + + def upscale_test(): + for rounding_mode in rounding_modes: + for use_int_compute in [True, False]: + func = get_func(data_shape=(32, ), + data_dtype='int32', + out_dtype="int8", + use_int_compute=use_int_compute, + rounding_mode=rounding_mode, + input_scale=2, + output_scale=1) + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype('int32') + golden_output = np.multiply(2, golden_data) + verify(func, (golden_data, golden_output)) + + # Try negative values + # -8 corresponds to -0.5. For FE_UPWARD, this is 0 + golden_data = np.arange(0, -32, -1).astype('int32') + golden_output = np.multiply(2, golden_data) + verify(func, (golden_data, golden_output)) + + def saturation_test(): + for rounding_mode in rounding_modes: + for use_int_compute in [True, False]: + func = get_func(data_shape=(16, ), + data_dtype='int32', + out_dtype="int8", + use_int_compute=use_int_compute, + rounding_mode=rounding_mode, + input_scale=0.5, + output_scale=0.5) + golden_data = np.arange(0, 16, 1).astype('int32') + golden_data = np.add(120, golden_data) + output = np.array([120, 121, 122, 123, 124, 125, 126, 127, + 127, 127, 127, 127, 127, 127, 127, 127]) + golden_output = output + verify(func, (golden_data, golden_output)) + + # Try negative numbers + golden_data = np.arange(0, -16, -1).astype('int32') + golden_data = np.add(-120, golden_data) + output = np.array([-120, -121, -122, -123, -124, -125, -126, -127, + -128, -128, -128, -128, -128, -128, -128, -128]) + golden_output = output + verify(func, (golden_data, golden_output)) + + def zero_point_test(): + # Output zero point + for rounding_mode in rounding_modes: + for use_int_compute in [True, False]: + func = get_func(data_shape=(32, ), + data_dtype='int32', + out_dtype="int32", + use_int_compute=use_int_compute, + rounding_mode=rounding_mode, + input_scale=1, + output_scale=16, + output_zero_point=1) + + # Try positive values + # 8 corresponds to 0.5, resulting in 1 + golden_data = np.arange(0, 32, 1).astype('int32') + golden_output = np.repeat([0, 1, 2], [8, 16, 8]) + golden_output = np.add(1, golden_output) + verify(func, (golden_data, golden_output)) + + # Try negative values + # -8 corresponds to -0.5. For FE_UPWARD, this is 0 + golden_data = np.arange(-32, -64, -1).astype('int32') + if use_int_compute == True and rounding_mode == "FE_UPWARD": + golden_output = np.repeat([-2, -3, -4], [9, 16, 7]) + else: + golden_output = np.repeat([-2, -3, -4], [8, 16, 8]) + golden_output = np.add(1, golden_output) + verify(func, (golden_data, golden_output)) + + # Input zero point + for rounding_mode in rounding_modes: + for use_int_compute in [True, False]: + func = get_func(data_shape=(32, ), + data_dtype='int32', + out_dtype="int32", + use_int_compute=use_int_compute, + rounding_mode=rounding_mode, + input_scale=1, + output_scale=16, + input_zero_point=16) + + # Try positive values + golden_data = np.arange(32, 64, 1).astype('int32') + golden_output = np.repeat([2, 3, 4], [8, 16, 8]) + golden_output = np.subtract(golden_output, 1) + verify(func, (golden_data, golden_output)) + + # Try negative values + golden_data = np.arange(-32, -64, -1).astype('int32') + if use_int_compute == True and rounding_mode == "FE_UPWARD": + golden_output = np.repeat([-2, -3, -4], [9, 16, 7]) + else: + golden_output = np.repeat([-2, -3, -4], [8, 16, 8]) + golden_output = np.subtract(golden_output, 1) + verify(func, (golden_data, golden_output)) + + + + + if __name__ == "__main__": + same_scale_test() + downscale_test() + upscale_test() + saturation_test() + zero_point_test() + + run_tests() + +if __name__ == "__main__": + test_requantize()