[QNN] Requantize operator #3531
Conversation
Let us distinguish the qnn from the quantize namespace, given what the primary lowering phase supports.
@tqchen Thanks for the quick comment :) Can you please elaborate a little bit? Do you mean to change this file - https://github.com/dmlc/tvm/pull/3531/files#diff-dcae58edc1986609a54e3a202ee49b3b
python/tvm/relay/quantize/rewrite.py
@tqchen Does it make sense to move this to a newly created directory - python/tvm/relay/op/qnn? This will ensure the quantize Python namespace is not polluted.
src/relay/pass/quantize_rewrite.cc
@tqchen Should I call this file qnn_rewrite.cc to have a clear marking between the quantize (quantization in Relay) and qnn (creating new quantized ops) work?
@FrozenGene Please have a look.
src/relay/pass/quantize_rewrite.cc
RELAY_REGISTER_OP("qnn.requantize")
.set_attr<FForwardRewrite>("FQuantizeForwardRewrite", RequantizeForwardRewrite);
@tqchen Does it make sense to replace FQuantizeForwardRewrite with FQnnForwardRewrite (for differentiating qnn and quantize)?
Given that qnn is a dialect, we could consider giving it a special namespace and folder to live in. We can debate on whether it should be relay/qnn/op or relay/op/qnn. cc @jroesch Perhaps let us open an RFC for the dialect convention. I have not yet taken a deep look into the PR, but here are a few high-level comments.
I see. Makes sense. I don't have a strong preference for any of the namespaces that you suggested. My favorite is src/relay/qnn. In this case, we are only dealing with ops, but I can believe that some other dialect might want more than just ops. Still, I am ok with any of the above options. Thanks for the high-level comments. I will fix the utils and put them in the src/relay/qnn directory. Will look for CamelCase. (I was hoping that lint would do that for me, but I guess not.)
@tqchen I have updated the code. The C++ files are in the src/relay/qnn directory. It has its own sub-directories like include, op, and pass to serve different purposes. Let me know how it looks.
python/tvm/relay/op/qnn/qnn.py
def requantize(input_data, input_zero_point, input_scale, output_zero_point,
               output_scale, out_dtype="int32", use_int_compute=False,
The following relay attr's default value is true. Please unify the default value (we should set it to True).
Good catch. Thanks :)
python/tvm/relay/op/qnn/qnn.py
def requantize(input_data, input_zero_point, input_scale, output_zero_point,
               output_scale, out_dtype="int32", use_int_compute=False,
               rounding_mode="FE_UPWARD"):
TFLite requires "round_away_from_zero". Could you explain why we set "FE_UPWARD" as the default value? Because of MXNet's requirement?
It can be "round_away_from_zero". These two rounding modes provide two points on the performance-accuracy tradeoff curve. FE_UPWARD needs fewer operators compared to round_away_from_zero. But I think you are right, we should set "round_away_from_zero" by default. If somebody wants better performance, they can set it to FE_UPWARD using a network-level config.
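For illustration, a minimal Python sketch (not from the PR; the helper names are made up) of how the two modes differ on tie values:

```python
import math

def fe_upward(x):
    # Ties are rounded toward +infinity; in fixed point this is just
    # "add half, then shift right", so it needs fewer operators.
    return math.floor(x + 0.5)

def round_away_from_zero(x):
    # Ties are rounded away from zero, matching std::round / TFLite.
    return int(math.copysign(math.floor(abs(x) + 0.5), x))

# The two modes only disagree on negative ties:
print(fe_upward(2.5), round_away_from_zero(2.5))    # 3 3
print(fe_upward(-2.5), round_away_from_zero(-2.5))  # -2 -3
```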
// input_tensor, the result in int64 where the decimal point is sitting
// between bits 31 and 30 (from the right, rightmost bit is bit 0).
Expr scalar = MakeConstantScalar(up_idtype, fixed_point_multiplier);
auto multiplied_t = Multiply(tensor, scalar);
Where is the overflow logic? If it overflows, we return int32_max.
Will look into this. I remember seeing an overflow check. Thanks for pointing it out.
@FrozenGene I thought a little more about it. I think we don't need this check. Please let me know if this makes sense.
The relevant portion from TFLite is
bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min();
Here, a is the input tensor and b is the fixed_point_multiplier. The fixed_point_multiplier is calculated using frexp. Now, frexp gives a negative fixed point multiplier only when the input floating point number is negative. In our case, this floating point number is input_scale/output_scale. I think the scales are always positive, so we should never have a negative number. (Let me know if I am wrong.)
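A small Python sketch of the same reasoning (this mirrors the usual frexp-based multiplier computation rather than the PR's exact GetFixedPointMultiplierShift code): a positive input_scale/output_scale ratio always gives a significand in [0.5, 1), so the Q31 multiplier stays positive and can never hit INT32_MIN, which is the only case the TFLite overflow check guards against.

```python
import math

def get_fixed_point_multiplier_shift(double_multiplier):
    # Decompose the scale ratio as significand * 2**exponent, significand in [0.5, 1).
    significand, exponent = math.frexp(double_multiplier)
    # Represent the significand as a Q31 fixed-point integer.
    fixed_point_multiplier = round(significand * (1 << 31))
    if fixed_point_multiplier == (1 << 31):  # rounding pushed it up to 1.0
        fixed_point_multiplier //= 2
        exponent += 1
    return fixed_point_multiplier, exponent

# For a positive ratio the multiplier lies in [2**30, 2**31 - 1]:
print(get_fixed_point_multiplier_shift(0.75))   # (1610612736, 0)
print(get_fixed_point_multiplier_shift(0.001))  # (1099511628, -9)
```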
Right.
I made a pass on it. It is mainly style stuff.
python/tvm/relay/op/qnn/__init__.py
# pylint: disable=wildcard-import
"""Neural network related operators."""
Quantized neural network?
src/relay/pass/pattern_util.h
inline Expr Full(Expr fill_value,
                 Array<IndexExpr> shape,
align
src/relay/qnn/include/attrs.h
}
};
remove one blank line
src/relay/qnn/include/util.h
/*!
 * \file tvm/relay/quantize_util.h
 * \brief Utility methods needs for quantized ops that can be shared
needed?
src/relay/qnn/op/requantize.cc
RELAY_REGISTER_OP("qnn.requantize")
.describe(R"code(Requantize operator.

FIXME
Fix what?
// 2) Subtract the input_zero_point
auto tensor = input_tensor;
tensor = Cast(tensor, up_idtype);
auto tensor = Cast(input_tensor, up_idtype); ?
auto input_zp = MakeConstantScalar(up_idtype, param->input_zero_point);
tensor = Subtract(tensor, input_zp);
}
Remove blank lines
Expr scalar = MakeConstantScalar(up_idtype, fixed_point_multiplier);
auto multiplied_t = Multiply(tensor, scalar);
remove one blank line
// 4) Find the rounding scalar. This depends on where the final decimal point
// sits. As we will be right shifting the multiplied_t, we need to first
// calculate the totol_right_shift.
total_right_shift
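For context, a rough Python sketch of the step this comment refers to (assumed from the standard fixed-point requantize recipe, not copied from the PR): after the Q31 multiply the value carries 31 extra fractional bits, so the total right shift is 31 minus the frexp exponent, and the rounding scalar is half of the amount about to be shifted away.

```python
def pos_rounding_value(exponent):
    # exponent comes from frexp of input_scale/output_scale (usually <= 0).
    total_right_shift = 31 - exponent
    # Adding this before the right shift implements round-half-up.
    return 1 << (total_right_shift - 1)

print(pos_rounding_value(0))   # 2**30
print(pos_rounding_value(-3))  # 2**33
```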
if __name__ == "__main__": |
Move xx_test under main below directly? run_test looks unnecessary.
How about separating them into individual test_ functions instead of packing them together into one giant test_requantize?
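Something like this layout, for instance (the test names here are just placeholders):

```python
def test_same_scale_requantize():
    # Same input/output scale: values should pass through unchanged.
    ...

def test_downscale_requantize():
    # Output scale larger than input scale: values shrink.
    ...

def test_upscale_requantize():
    # Output scale smaller than input scale: values grow.
    ...

if __name__ == "__main__":
    test_same_scale_requantize()
    test_downscale_requantize()
    test_upscale_requantize()
```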
return func

def run_tests():
Add a uint8 test. Like TFLite, whose out dtype is uint8. Let us cover it.
Good point. Will add it by tonight.
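A sketch of what such a test could look like (written against the qnn requantize API as it exists in TVM today; the signature in this PR revision uses scalar attrs rather than relay constants, so treat the exact call as an assumption):

```python
import tvm
from tvm import relay

def test_uint8_requantize():
    # int32 -> uint8 requantization, mimicking TFLite's usual output dtype.
    data = relay.var("data", shape=(2, 5), dtype="int32")
    out = relay.qnn.op.requantize(
        data,
        input_scale=relay.const(0.5, "float32"),
        input_zero_point=relay.const(0, "int32"),
        output_scale=relay.const(0.5, "float32"),
        output_zero_point=relay.const(128, "int32"),
        out_dtype="uint8",
    )
    func = relay.Function([data], out)
    mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))
    print(mod)

if __name__ == "__main__":
    test_uint8_requantize()
```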
src/relay/qnn/util.h
namespace tvm {
namespace relay {

inline bool IsInt8(const DataType& dtype) {
This is certainly too small a function to introduce another level of indirection. Directly remove it and keep things in the original condition.
src/relay/qnn/util.h
*/

/*!
 * \file tvm/relay/quantize_util.h
wrong file name
src/relay/qnn/util.h
|| IsInt16(dtype) || IsUint16(dtype);
}

enum class QuantizeOpType : uint8_t {
Unless there is a need (data compactness), you don't want to use uint8_t as the storage type.
src/relay/qnn/util.h
return dtype == Float(32);
}

inline bool IsQuantizedType(const DataType& dtype) {
IsQNNDataType
rename to qnn_lower.cc
python/tvm/relay/qnn/op/qnn.py
from __future__ import absolute_import as _abs
from . import _make

def requantize(input_data, input_zero_point, input_scale, output_zero_point,
Please debate the API naming; in particular, refer to possible existing API names and the choices of options.
I made a few quick reviews. @FrozenGene @zhiics @antinucleon @hlu1 please help check again if you have time. Let us specifically spend some time to discuss the API, in particular, what the choices of parameter names, order, and options are, whether they make sense or not, and whether there are existing APIs that we can be consistent with, etc.
include/tvm/relay/qnn/attrs.h
.describe("Defines the rounding direction when the value is midway between"
          "two representable values. There are two supported modes - FE_UPWARD"
          "or FE_AWAY_FROM_ZERO. More context can be found at"
          "https://www.gnu.org/software/libc/manual/html_node/Rounding.html");
Suggest explaining the meaning of the two modes here as well, to make it self-explanatory.
Thanks, added the description.
python/tvm/relay/qnn/op/qnn.py
Parameters
----------
quantized_data : tvm.relay.Expr
doc does not match the arguments in the function
src/relay/qnn/util.h
} else if (IsUint32(dtype)) {
  return std::numeric_limits<uint32_t>::min();
}
LOG(FATAL) << "Type not supported\n";
Also << dtype to make debugging easier. I also think we can remove the "\n".
RELAY_REGISTER_OP("qnn.requantize")
.set_attr<FForwardRewrite>("FQuantizeForwardRewrite", RequantizeForwardRewrite);

TVM_REGISTER_API("relay._qnn.rewrite")
should this be ported to pass manager? @zhiics
Talked to @anijain2305 offline; it sounds like this pass is more or less a dialect that is only performed in the frontend parsers. I think we probably don't want to bring it into the pass manager because 1) it is quite frontend specific, and 2) it only does Expr -> Expr transformation.
We could probably use better naming instead of rewrite here and ir_pass in Python. Since we have pass under the dialect folder, should we have transform as well?
@tqchen thoughts?
Yes, I think we need to use transform.
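For illustration, a usage sketch of a transform-style entry point (this uses the qnn.transform namespace that the dialect eventually settled on; the exact pass name for this PR revision is an assumption):

```python
import tvm
from tvm import relay

def lower_qnn(mod):
    # Run QNN lowering as a module-to-module pass so it composes with the
    # pass manager, rather than as a bare Expr -> Expr rewrite.
    seq = tvm.transform.Sequential([
        relay.qnn.transform.CanonicalizeOps(),
        relay.transform.InferType(),
    ])
    return seq(mod)
```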
const Array<Expr>& new_args, const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 1);
  Expr quantized_data = new_args[0];
  const auto* param = ref_call->attrs.as<RequantizeAttrs>();
What if param == nullptr?
src/relay/qnn/util.h
    return IsFloat32(in_dtype) || IsQuantizedType(in_dtype);
  case QuantizeOpType ::Dequantize:
    return IsQuantizedType(in_dtype);
  case QuantizeOpType ::Requantize:
remove space before ::
Can you also look at #3512 from an API naming perspective as well?
Let me write down my initial pass over the API design. Please help me refine it further.
Order of parameters
Parameter names - I could not find a specific pattern in the names. However, I think these are some common rules.
Current proposal - Adhering to the above rules, the API looks like this.
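A concrete strawman of what that proposal amounts to (reconstructed here from the parameter names visible in the diff; the exact listing is not reproduced in this thread, so treat it as an approximation):

```python
def requantize(input_data, input_scale, input_zero_point,
               output_scale, output_zero_point,
               rounding_mode="round_away_from_zero", out_dtype="int8"):
    """Requantize input_data from the input (scale, zero_point) pair to the
    output (scale, zero_point) pair.

    The data comes first, then the input quantization parameters, then the
    output quantization parameters, and the optional knobs (rounding mode,
    output dtype) come last as keyword arguments.
    """
```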
Thanks for all the comments. I fixed them. Please review once again.
Can we also list the related APIs from existing frameworks, e.g. TFLite, so we can get a good reference? Right now I can understand most of the parameters, except use_int_domain, which was a bit unclear to me.
For TF, the requantize operator looks like this:
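(The original listing is not reproduced in this thread; from memory, TF's raw Requantize kernel has roughly the interface below, so treat this as an approximation rather than the exact reference.)

```python
import tensorflow as tf

def tf_requantize(quantized_input, in_min, in_max, req_min, req_max):
    # quantized_input: qint32 tensor; the float pairs give the real-valued
    # range represented by the input and the requested output range.
    return tf.raw_ops.Requantize(
        input=quantized_input,
        input_min=in_min,
        input_max=in_max,
        requested_output_min=req_min,
        requested_output_max=req_max,
        out_type=tf.quint8,
    )
```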
I think all the above arguments are represented in the current API proposal (min/max are represented as scale/zero_point).
However, if you think we will never use FP32 computation for requantize, that is also a reasonable answer.
Maybe we don't need
Thanks @FrozenGene and @tqchen. Spent some more time thinking about it. Realized that it might not be necessary. Removed it. Please review again.
@tqchen Please let me know if you got a chance to look at the changes. By the way, phisiart was added for review automatically when I was rebasing. Maybe I made a mistake in rebasing. I don't know how to remove the reviewer.
@tqchen Can we try to get this merged if everything looks good? That will open up discussion/review for many other PRs.
* [Relay] [Quantization] WIP - Common files for the quantization work.
* [Relay] [Quantization] WIP - Prototyping requantize op.
* Requantize operator implementation. Requantize converts one quantized tensor representation to another quantized representation. The PR has the following implementation features:
  - Requantize operator defined in qnn namespace - relay.qnn.requantize
  - Lowering of the requantize to existing Relay operators
  - Integer fixed point implementation of requantize, with two rounding modes - FE_UPWARDS (round towards infinity) and FE_AWAY_FROM_ZERO (std::round behavior)
  - Floating point implementation as well, that can act as reference or can be used for devices when FP32 computation is not used.
  - Unit test cases
  Relevant Issue - apache#2351
  Credit to TFLite and GemmLowp to provide reference implementations.
* Typo and lint fixes.
* Doc fix.
* Uncommenting the lint script (fixing mistake).
* Modifying the unit tests.
* Moving C++ files into src/relay/qnn
* Moving python files to python/tvm/relay/qnn. Some minor fixes.
* Moving the attrs.h inside the include directory.
* Pushing files that I forgot earlier. Changing util location.
* Incorporating comments. API change. Lint fixes.
* Modifying the GetFixedPointMultiplierShift API as per comments.
* Forgot the dialect change.
* Changing rewrite to qnn_lower.
* Renaming Quantize to Qnn for clarity.
* Remove use_int_domain.
* Incorporating review comments.
* Adding API doc for QNN dialect.
* Move the qnn_lower pass to transform namespace.
* Moving from expr to module. Adding namespace in C++.
* Minor sentence rewrites. Added qnn namespace.
* Added the API doc.
* Changing default out_dtype to int8. Adding a test with in/out_dtype as uint8.
* Style fixes. Better error messages.
* Adding documentation.
* More documentation fixes.
* Adding out dtype check for requantize.
* Adding corner case for FP32 to fixed point conversion.
* Adding extra line.
* Documentation fix.
* Adding static inline.
* Incorporating jackwish comment. Removed idtype from requantize lowering.
* Removing Quantize/Dequantize code. Restricting Requantize to (u)int8/int32.
* Style fixes.
* Fix the docs.
* Move to Legalize API.
Requantize converts one quantized tensor representation to another quantized representation. The PR has the following implementation features:
- Requantize operator defined in the qnn namespace - relay.qnn.requantize
- Lowering of the requantize to existing Relay operators
- Integer fixed point implementation of requantize, with two rounding modes - FE_UPWARDS (round towards infinity) and FE_AWAY_FROM_ZERO (std::round behavior)
- Floating point implementation as well, that can act as reference or can be used for devices when FP32 computation is not expensive.
- Unit test cases

Relevant Issue - #2351

Credit to TFLite and GemmLowp to provide reference implementations.