Skip to content

Commit

Permalink
Comments
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 committed Oct 28, 2020
1 parent 4c86a86 commit 2a9fce4
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 66 deletions.
21 changes: 8 additions & 13 deletions src/relay/qnn/op/requantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,22 +155,17 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
if (!IsEqualScalar(input_scale, output_scale)) {
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);

const bool is_upward_rounding = (param->rounding == "UPWARD");

if (is_upward_rounding && fixed_point_multiplier == (1 << 30)) {
// Power of 2 is determined by the fixed_point_multiplier == 1 << 30. In case of power of 2,
// fixed point multiplier will represent a float value of 0.5. In fixed point, this is
// represented by 1 << 30.
scaled_int32_t = PowerOfTwoMultiply(scaled_int32_t, shift - 1);
} else {
// When using upward rounding (i.e., x.5 rounded to x+1), leverage
// the FixedPointMultiply operator
scaled_int32_t =
(is_upward_rounding
? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift)
: FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape));
}
// When using upward rounding (i.e., x.5 rounded to x+1), leverage
// the FixedPointMultiply operator
scaled_int32_t =
(is_upward_rounding
? FixedPointMultiply(scaled_int32_t, fixed_point_multiplier, shift)
: FixedPointMultiplyToNearest(scaled_int32_t, double_multiplier, input_shape));
}

} else {
// This is per-channel (per-axis) quantization.
std::vector<double> double_multipliers;
Expand Down
15 changes: 0 additions & 15 deletions src/relay/qnn/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,6 @@ std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplie
return std::make_pair(significand, exponent);
}

/*
 * \brief Multiply an integer tensor by 2^exp using shift operations.
 * \param tensor The input tensor; shift amounts and the rounding constant are
 *        materialized as int32 scalars.
 * \param exp The power of two. Positive values lower to a left shift, negative
 *        values to a rounding right shift, and zero leaves the tensor untouched.
 * \return The sequence of Relay ops implementing the multiplication.
 */
Expr PowerOfTwoMultiply(Expr tensor, int32_t exp) {
  if (exp == 0) {
    // 2^0 == 1, so there is nothing to do. Handling this case up front also
    // keeps the rounding path below from evaluating 1 << -1, which is
    // undefined behavior in C++.
    return tensor;
  }
  if (exp > 0) {
    // Positive power of 2: apply a left shift.
    return LeftShift(tensor, MakeConstantScalar(DataType::Int(32), exp));
  }
  // Negative power of 2: add half of the divisor so that the subsequent right
  // shift rounds instead of truncating.
  const int32_t right_shift = -exp;
  const int32_t rounding_factor = 1 << (right_shift - 1);
  Expr rounded_t = Add(tensor, MakeConstantScalar(DataType::Int(32), rounding_factor));
  return RightShift(rounded_t, MakeConstantScalar(DataType::Int(32), right_shift));
}

Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
const Array<IndexExpr>& input_shape) {
// Choose high precision datatype to be int64. This is for avoiding overflow
Expand Down
7 changes: 0 additions & 7 deletions src/relay/qnn/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,6 @@ static inline int64_t get_const_int(const tvm::PrimExpr& x) {
*/
Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
const Array<IndexExpr>& input_shape);
/*
* \brief Multiply an integer datatype tensor by a power of two.
* \param tensor The quantized input tensor of dtype int32.
* \param exp The exp or the power of 2 representing the number to be multiplied.
* \return The sequence of Relay ops for power of two multiplication.
*/
Expr PowerOfTwoMultiply(Expr tensor, int32_t exp);

/*
* \brief Fixed point multiplication between integer tensor with floating point
Expand Down
90 changes: 59 additions & 31 deletions src/target/intrin_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,37 +128,65 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.q_multiply_shift")
PrimExpr q = call->args[2];
PrimExpr s = call->args[3];

// Only int32 types are supported (any number of lanes is allowed)
ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32);
ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32);

DataType hp_dtype = DataType::Int(64, x.dtype().lanes());
DataType lp_dtype = DataType::Int(32, x.dtype().lanes());

// 1) Calculating the integer multiplier and integer shift
PrimExpr zero = make_const(s.dtype(), 0);
PrimExpr left_shift = tir::Select(s > zero, s, zero);
PrimExpr right_shift = tir::Select(s > zero, zero, -s);

// 2) Cast and Multiply the integer multiplier
PrimExpr one = make_const(hp_dtype, 1);
x = cast(hp_dtype, x);
y = cast(hp_dtype, y);
x = tir::Select(left_shift != zero, x << left_shift, x);

// 3) Perform the multiplication in higher precision.
x = x * y;

// 4) Find the rounding scalar
PrimExpr total_right_shift = right_shift + q;
PrimExpr pos_rounding_value = (one << (total_right_shift - 1));
x = x + pos_rounding_value;

// 5) Simply right shift the result to get the final output.
x = x >> total_right_shift;

// 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
*rv = cast(lp_dtype, x);
// Lambda function to extract the integer constant from a PrimExpr.
// Accepts either a plain IntImm (scalar case) or a Broadcast whose value is an
// IntImm (vectorized case); any other expression fails the CHECKs below.
auto get_int_value = [](const PrimExpr& node) {
  // Scalar case: the expression is already an integer immediate.
  if (auto scalar_node = node.as<IntImmNode>()) {
    return scalar_node->value;
  }
  // Vector case: unwrap the broadcast to reach the underlying immediate.
  auto broadcast_node = node.as<BroadcastNode>();
  CHECK(broadcast_node != nullptr);
  auto int_node = broadcast_node->value.as<IntImmNode>();
  CHECK(int_node != nullptr);
  return int_node->value;
};
// Power of 2 is determined by the fixed_point_multiplier == 1 << 30. In case of power of 2,
// fixed point multiplier will represent a float value of 0.5. In fixed point, this is
// represented by 1 << 30.
if (get_int_value(y) == (1 << 30)) {
PrimExpr exp = s - 1;
int exp_val = get_int_value(s) - 1;
if (exp_val > 0) {
// power of 2 is greater than 0, apply left shift.
*rv = x << exp;
} else {
// power of 2 is less than 0, round and then apply right shift.
DataType lp_dtype = DataType::Int(32, x.dtype().lanes());
PrimExpr one = make_const(lp_dtype, 1);
exp = -exp;
PrimExpr rounding_factor = one << (exp - 1);
PrimExpr rounded_t = x + rounding_factor;
*rv = rounded_t >> exp;
}
} else {
// Only int32 types are supported (any number of lanes is allowed)
ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32);
ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32);

DataType hp_dtype = DataType::Int(64, x.dtype().lanes());
DataType lp_dtype = DataType::Int(32, x.dtype().lanes());

// 1) Calculating the integer multiplier and integer shift
PrimExpr zero = make_const(s.dtype(), 0);
PrimExpr left_shift = tir::Select(s > zero, s, zero);
PrimExpr right_shift = tir::Select(s > zero, zero, -s);

// 2) Cast and Multiply the integer multiplier
PrimExpr one = make_const(hp_dtype, 1);
x = cast(hp_dtype, x);
y = cast(hp_dtype, y);
x = tir::Select(left_shift != zero, x << left_shift, x);

// 3) Perform the multiplication in higher precision.
x = x * y;

// 4) Find the rounding scalar
PrimExpr total_right_shift = right_shift + q;
PrimExpr pos_rounding_value = (one << (total_right_shift - 1));
x = x + pos_rounding_value;

// 5) Simply right shift the result to get the final output.
x = x >> total_right_shift;

// 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
*rv = cast(lp_dtype, x);
}
});

} // namespace intrin
Expand Down

0 comments on commit 2a9fce4

Please sign in to comment.