
[OPT] Low-bit Quantization #2116

Merged (5 commits), Jan 31, 2019

Conversation

ZihengJiang (Contributor)

Thanks for contributing to TVM! Please refer to the contribution guidelines at https://docs.tvm.ai/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from Reviewers.

@ajtulloch (Contributor)

This is all assuming a symmetric quantization scheme, correct? Have you considered generalizing this slightly to an asymmetric quantization scheme like the one used in GEMMLOWP, QNNPACK, FBGEMM, NNAPI, etc?

@tqchen added the status: need RFC label on Nov 15, 2018
@tqchen (Member) commented Nov 15, 2018

Since quantization is a major feature, it is better to send an RFC first.

@ZihengJiang (Contributor, Author)

I will propose an RFC next week. Thanks @ajtulloch @tqchen.

@ajtulloch (Contributor)

Has there been an RFC posted btw? This comment probably belongs there.

FWIW I'm a little concerned about some directions this PR is taking, or at least there are some use cases that would be good to see handled, and I don't currently see how they fit in.

For background on my perspective, a standard training flow for quantized models in TF/C2 (at least the frameworks I'm familiar with that implement this) is to:

  1. Implement a model in a standard ML framework, generally using fp16/bfloat16/fp32 compute precision as this has highest throughput on most commonly-used training hardware.
  2. (optionally) insert fake quantization (here, called simulated quantization) nodes at quantization boundaries (i.e. if your backend implements a fused Int8Conv + Int8Relu, you'd insert them after a Conv + Relu block), to simulate the quantization numerics at training time (see the sketch after this list).
  3. Train the model as usual
  4. Implement a graph rewriting pass (i.e. TF's toco, C2's int8_converter, MXNet's quantization, etc) that rewrites the graph to target the int8 operators directly — i.e. remapping subgraphs of e.g. FP32Conv + FP32Relu to be a fused Int8ConvRelu operator. This requires computing output quantization parameters at requantization boundaries, which can be done either by
    • calibration to an example set of activations, via e.g. l-p norm or kl minimization (c2/tf/mxnet/tensorrt)
    • using activation ranges learned during training (c2/tf).
  5. Using this quantized graph, evaluate various metrics to verify the quantization-induced error/loss is acceptable.
  6. Deploy the quantized graph.
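To make step 2 concrete, here is a minimal NumPy sketch of the fake/simulated quantization numerics. An asymmetric uint8 scheme is assumed, and the function name, bit width, and min/max handling are illustrative assumptions rather than the exact numerics of TF, C2, or this PR:

```python
import numpy as np

def fake_quantize(x, num_bits=8, x_min=None, x_max=None):
    """Quantize-then-dequantize in fp32, so training sees the rounding and
    clamping error that real uint8 inference would introduce."""
    x_min = float(np.min(x)) if x_min is None else x_min
    x_max = float(np.max(x)) if x_max is None else x_max
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)   # keep 0.0 exactly representable
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (x_max - x_min) / (qmax - qmin) or 1.0
    zero_point = int(round(qmin - x_min / scale))
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
    return ((q - zero_point) * scale).astype(np.float32)

x = np.random.randn(4, 8).astype("float32")
x_fq = fake_quantize(x)              # what the layer "sees" during training
print(np.max(np.abs(x - x_fq)))      # bounded by roughly scale / 2
```

In this flow, the scale/zero-point pair used for simulation (or ranges obtained by calibration) is what the rewriting pass in step 4 would reuse as the real quantization parameters of the int8 operators.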

Does this workflow make sense to folks? If not, could folks please elaborate on where we differ?

Given this flow, we'd like to insert TVM into this process. One key use case that I'd like TVM to consider supporting is to allow frameworks to continue to use their existing approaches for Steps 1-5, and involve TVM in Step 6. There are several reasons for this: for example, calibration-based quantization isn't always sufficient, and we'd like to support importing from existing int8 graph IRs like TFLite or C2.

I think requiring TVM to take on Steps 4 and 5 in order to implement quantized models is unnecessarily opinionated, and moves it towards being a fully-fledged framework in its own right (which I thought was not the goal).

I would have thought one natural (and minimalistic) direction for TVM to support quantized models (which isn't precluded by this diff, but I want to see what folks think about this) would be something like:

  1. Implement (in topi) support for int8 ops (i.e. (u)int8 inputs, int32 accumulation, int32 output). This is partially done already by the great work from folks in the community. If we generalize to asymmetric quantization (which IMO is quite important), then it's arguably more natural to represent the inputs/outputs as tuples of (uint8 tensor, float min, float max) or equivalently (uint8 tensor, int32 bias, float scale), and implement operators using this representation (see the sketch after this list).
  2. Add some kind of requantize op in NNVM that performs an int32 -> (u)int8 requantization with the appropriate output float min/float max obtained via calibration or training.
  3. Implement in nnvm frontend an importer for e.g. tflite models (which would mostly involve mapping ops like TFLiteConv into a nnvm::Conv + nnvm::Requantize sequence, and ensuring that TVM/NNVM fuse away sequences of requantize/pointwise/requantize), and demonstrate a) bitwise numerical equivalence, and b) speedups vs tflite's runtime for models like MobileNetV2 or similar.
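A minimal NumPy sketch of points 1-2 under the asymmetric representation suggested above (uint8 payload, zero point, float scale). The names quantize/dequantize/requantize and the concrete scales are illustrative assumptions, not an existing TVM/NNVM API; in practice the output scale/zero point would come from calibration or training:

```python
import numpy as np

def quantize(x, scale, zero_point):
    """fp32 -> uint8 under an asymmetric per-tensor scheme."""
    q = np.round(x / scale) + zero_point
    return np.clip(q, 0, 255).astype(np.uint8)

def dequantize(q, scale, zero_point):
    return (q.astype(np.int32) - zero_point).astype(np.float32) * scale

def requantize(acc_i32, acc_scale, out_scale, out_zero_point):
    """int32 accumulator -> uint8 at the output scale: the job of the
    proposed 'requantize' op at the end of an int8 conv/dense block."""
    q = np.round(acc_i32 * (acc_scale / out_scale)) + out_zero_point
    return np.clip(q, 0, 255).astype(np.uint8)

# Toy int8 "matmul": uint8 inputs, int32 accumulation, then requantize.
x_scale, x_zp = 0.03, 128
w_scale, w_zp = 0.03, 128
x = quantize(np.random.randn(16).astype("float32"), x_scale, x_zp)
w = quantize(np.random.randn(16).astype("float32"), w_scale, w_zp)
acc = np.dot(x.astype(np.int32) - x_zp, w.astype(np.int32) - w_zp)   # int32
y = requantize(acc, acc_scale=x_scale * w_scale, out_scale=0.1, out_zero_point=128)
```

With an importer as in point 3, a TFLiteConv would then map onto this conv + requantize pattern, and fusing the requantize into the conv epilogue keeps the int32 accumulator out of memory.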

Concretely, my concerns with this approach (assuming the goal is to be 'the one true way' to execute quantized models in TVM) are that it a) integrates too early in the pipeline, which unnecessarily requires some assumptions, b) these assumptions aren't the most general ones (i.e. it requires symmetric quantization as used by e.g. MKLDNN), which precludes asymmetric quantization as in TF, TFLite, C2, GEMMLOWP, QNNPACK, and channel-wise quantization as in TF/C2, which is very useful for pushing bitwidths lower (see e.g. https://arxiv.org/pdf/1806.08342.pdf), and c) it is less modular than other approaches, which makes it harder to target from existing frameworks that already support quantization.

I don't think our goals are in conflict, I just thought that I should put this on the radar. Happy to send out an RFC (and dedicate engineering effort) to the alternative approach as well if folks are on board?

@tqchen (Member) commented Dec 6, 2018

@ajtulloch an RFC needs to be sent out and we won't merge the PR before the RFC gets discussed, so we can move the discussion there after it gets posted.

@ZihengJiang (Contributor, Author)

Hi @ajtulloch, I have a paper deadline, so I pushed this PR forward in a hurry to get a workable quantization workflow. Let me send out an RFC tomorrow. This PR won't be merged before we have a discussion in the community.

@tqchen (Member) commented Dec 6, 2018

x

@lixiaoquan (Contributor)

Currently, it seems NNVM requires the inputs of an op to have the same data type, but a quantization scheme may produce inputs of different types. Any suggestions about that?

@ajtulloch (Contributor)

@lixiaoquan there's no such requirement today AFAIK, it's user-controlled in the implementation of attr<FInferType>(..) for the relevant NNVM op.

@@ -213,3 +214,16 @@ def select_array(i, j):
return now

return tvm.compute(matrix.shape, select_array, name=name)


@tvm.register_func("print_tensor")
Member:

Sure, maybe we can add it as a util later in a separate PR, but we need documentation for these.


@ZihengJiang changed the title from [WIP] Low-bit Quantization to [OPT] Low-bit Quantization on Dec 18, 2018
@ZihengJiang (Contributor, Author) commented Jan 17, 2019

@liangfu Thanks for catching this outdated test

@@ -124,7 +124,7 @@ def _bind_params_by_name(func, params):
return expr.bind(func, bind_dict)


-def optimize(func, target, params=None):
+def optimize(func, target=None, params=None):
Contributor (Author):

It seems this API changed recently? It breaks some code. @tqchen

@ZihengJiang (Contributor, Author)

Here is an evaluation script: https://gist.github.com/ZihengJiang/bcabe46a712a417a01a6967d4430b6b5

@eqy @vinx13 @liangfu

@tqchen (Member) commented Jan 18, 2019

@antinucleon @hlu1 @anijain2305 please also help take a look when you have time

@eqy (Contributor) commented Jan 18, 2019

@ZihengJiang sorry, this is a basic question, but is there support for mixed quantization levels? It looks like we currently only specify a global weight and activation precision. Since we can already skip the first k conv layers, it seems that this would be a useful generalization.

@eqy (Contributor) left a review comment: typo

@ZihengJiang (Contributor, Author)

@eqy Users can override the rewrite function to implement mixed-precision quantization, but that is not included in this PR.
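Purely as an illustration of the idea (none of this is the PR's API; the table, layer names, and simulate_weight helper are hypothetical), mixed precision just means the simulated bit width is looked up per layer instead of taken from a single global setting:

```python
import numpy as np

# Hypothetical per-layer bit widths that an overridden rewrite/annotate
# function could consult instead of a single global precision.
LAYER_NBITS = {"conv0": 8, "conv1": 8, "conv2": 4, "fc": 8}

def simulate_weight(name, w):
    """Simulate symmetric signed quantization of a weight tensor at the
    bit width chosen for this layer."""
    nbit = LAYER_NBITS.get(name, 8)
    qmax = 2 ** (nbit - 1) - 1
    wmax = float(np.max(np.abs(w)))
    scale = wmax / qmax if wmax > 0 else 1.0
    return np.clip(np.round(w / scale), -qmax - 1, qmax) * scale

w = np.random.randn(64, 64).astype("float32")
for name, nbit in LAYER_NBITS.items():
    err = np.max(np.abs(w - simulate_weight(name, w)))
    print(f"{name}: {nbit}-bit, max quantization error {err:.4f}")
```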

@vinx13 (Member) commented Jan 25, 2019

In ResNet, we use int32 for the residual addition, but I found that saving intermediate int32 results to global memory is much slower. Is it possible to use int8 in this case (we would need to modify the annotation of add)? I'm not sure about the impact on model precision.
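A NumPy sketch of the trade-off (assumed per-tensor scales, not what this PR currently emits): if both residual branches are requantized to a shared int8 scale before the add, the int32 intermediates never have to be written to global memory, at the cost of one extra rounding step whose effect on accuracy would need to be measured:

```python
import numpy as np

def requant_to_int8(acc_i32, acc_scale, out_scale):
    """Rescale an int32 accumulator to int8 at a chosen common scale."""
    q = np.round(acc_i32 * (acc_scale / out_scale))
    return np.clip(q, -128, 127).astype(np.int8)

# Two residual branches with int32 accumulators at (possibly) different scales.
a_i32 = np.random.randint(-5000, 5000, size=64, dtype=np.int32)
b_i32 = np.random.randint(-5000, 5000, size=64, dtype=np.int32)
a_scale, b_scale = 1e-3, 2e-3

# Current behavior described above: add in int32 (real value = q * scale).
ref = a_i32 * a_scale + b_i32 * b_scale

# Alternative: requantize each branch to a shared int8 scale, then add narrowly.
s = 0.1
a8 = requant_to_int8(a_i32, a_scale, s)
b8 = requant_to_int8(b_i32, b_scale, s)
approx = (a8.astype(np.int16) + b8.astype(np.int16)) * s

print(np.max(np.abs(ref - approx)))   # extra error from the early requantization
```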

@ZihengJiang ZihengJiang merged commit 741b6bb into apache:master Jan 31, 2019
@ZihengJiang added the status: accepted label and removed the status: need review and status: need update labels on Feb 2, 2019
libing4752 pushed a commit to libing4752/tvm that referenced this pull request Feb 18, 2019
* [QUANTIZE] Quantization implementation.

* Update.

* Update.

* Update.

* Update.
merrymercy pushed a commit to merrymercy/tvm that referenced this pull request Feb 18, 2019
* [QUANTIZE] Quantization implementation.

* Update.

* Update.

* Update.

* Update.
wweic pushed a commit to neo-ai/tvm that referenced this pull request Feb 20, 2019
* [QUANTIZE] Quantization implementation.

* Update.

* Update.

* Update.

* Update.
@yzhliu mentioned this pull request on Mar 2, 2019
@YiranCdr commented Aug 4, 2020

Hey guys, I'm wondering whether TVM supports any INT16 quantization. If the answer is yes, is it quantization-aware training or post-training quantization? Thanks!
